Skip to content

Commit 25457f5

Browse files
Merge pull request #1959 from FluxML/cl/oh
onehotbatch with CuArray
2 parents 12bad50 + 96cc8bc commit 25457f5

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/onehot.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
185185
"""
186186
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)
187187

188+
_onehotbatch(data::CuArray, labels) = _onehotbatch(data |> cpu, labels) |> gpu
189+
188190
function _onehotbatch(data, labels)
189191
indices = UInt32[something(_findval(i, labels), 0) for i in data]
190192
if 0 in indices

test/cuda/cuda.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ end
4545

4646
gA = rand(3, 2) |> gpu;
4747
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
48+
49+
# construct from CuArray
50+
x = [1, 3, 2]
51+
y = Flux.onehotbatch(x, 0:3)
52+
y2 = Flux.onehotbatch(x |> gpu, 0:3)
53+
@test y2.indices isa CuArray
54+
@test y2 |> cpu == y
4855
end
4956

5057
@testset "onecold gpu" begin

0 commit comments

Comments
 (0)