@@ -79,14 +79,14 @@ Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x
7979Base. similar (x:: OneHotArray{<:Any,<:Any,<:Any,<:AbstractArray} , :: Type{T} , size:: Base.Dims ) where T =
8080 similar (x. indices, T, size)
8181
82- function Base. copyto! (dst:: AbstractArray{T,N} , src:: OneHotArray{<:Any,Nm1 ,N,<:AbstractArray} ) where {T,N,Nm1 }
82+ function Base. copyto! (dst:: AbstractArray{T,N} , src:: OneHotArray{<:Any,<:Any ,N,<:AbstractArray} ) where {T,N}
8383 size (dst) == size (src) || return invoke (copyto!, Tuple{typeof (dst), AbstractArray{Bool,N}})
84- # fill!(dst, false)
85- # setindex!.(eachslice(dst; dims=ntuple(d->d+1, Nm1)), true, src.indices)
86- # setindex!.(Ref(dst), true, src.indices, axes(src.indices)...)
87- dst .= reshape (src. indices, 1 , size (src. indices)... ) .== (1 : src. nlabels) # this works at REPL!
84+ dst .= reshape (src. indices, 1 , size (src. indices)... ) .== (1 : src. nlabels)
8885 return dst
8986end
87+ function Base. copyto! (dst:: Array{T,N} , src:: OneHotArray{<:Any,<:Any,N,<:AnyGPUArray} ) where {T,N}
88+ copyto! (dst, adapt (Array, src))
89+ end
9090
9191function Base. showarg (io:: IO , x:: OneHotArray , toplevel)
9292 print (io, ndims (x) == 1 ? " OneHotVector(" : ndims (x) == 2 ? " OneHotMatrix(" : " OneHotArray(" )
0 commit comments