Skip to content

Commit f3343a1

Browse files
committed
copyto method using Adapt move to CPU
1 parent 0a5350c commit f3343a1

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

src/array.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ Base.getindex(x::OneHotArray{<:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x
7979
Base.similar(x::OneHotArray{<:Any,<:Any,<:Any,<:AbstractArray}, ::Type{T}, size::Base.Dims) where T =
8080
similar(x.indices, T, size)
8181

82-
function Base.copyto!(dst::AbstractArray{T,N}, src::OneHotArray{<:Any,Nm1,N,<:AbstractArray}) where {T,N,Nm1}
82+
function Base.copyto!(dst::AbstractArray{T,N}, src::OneHotArray{<:Any,<:Any,N,<:AbstractArray}) where {T,N}
8383
size(dst) == size(src) || return invoke(copyto!, Tuple{typeof(dst), AbstractArray{Bool,N}})
84-
# fill!(dst, false)
85-
# setindex!.(eachslice(dst; dims=ntuple(d->d+1, Nm1)), true, src.indices)
86-
# setindex!.(Ref(dst), true, src.indices, axes(src.indices)...)
87-
dst .= reshape(src.indices, 1, size(src.indices)...) .== (1:src.nlabels) # this works at REPL!
84+
dst .= reshape(src.indices, 1, size(src.indices)...) .== (1:src.nlabels)
8885
return dst
8986
end
87+
function Base.copyto!(dst::Array{T,N}, src::OneHotArray{<:Any,<:Any,N,<:AnyGPUArray}) where {T,N}
88+
copyto!(dst, adapt(Array, src))
89+
end
9090

9191
function Base.showarg(io::IO, x::OneHotArray, toplevel)
9292
print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(")

test/gpu.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ end
3737

3838
# These were broken on OneHotArrays v0.2.7
3939
@test @allowscalar cx[2,2] == x[2,2]
40-
@test_broken collect(cx) == collect(x)
41-
@test_broken Matrix(cx) == Matrix(x) == collect(x)
42-
@test_broken Array{Float32}(cx) == Array{Float32}(x) == collect(x)
40+
@test collect(cx) == collect(x)
41+
@test Matrix(cx) == Matrix(x) == collect(x)
42+
@test Array{Float32}(cx) == Array{Float32}(x) == collect(x)
4343
@test convert(AbstractArray{Float32}, cx) isa CuArray{Float32}
44+
@test collect(convert(AbstractArray{Float32}, cx)) == collect(x)
4445
end
4546

4647
@testset "onehot gpu" begin

0 commit comments

Comments
 (0)