Skip to content

Commit 72356fd

Browse files
committed
gather throw error on cpu and gpu array inputs
add tests
1 parent da73f07 commit 72356fd

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/gather.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,8 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
6363
kernel(args...; threads=threads, blocks=blocks)
6464
return dst
6565
end
66+
67+
function NNlib.gather(src::AnyCuArray, idx::AbstractArray)
68+
err_msg = "src and idx both must be on GPU, but received $(typeof(src)) and $(typeof(idx)), respectively."
69+
throw(ArgumentError(err_msg))
70+
end

test/gather.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
1818
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
1919
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
20+
@test_throws ArgumentError NNlib.gather(src, collect(index))
2021

2122
## 1d src, 2d index of tuples -> 2d output
2223
src = CT([3, 4, 5, 6, 7])

0 commit comments

Comments
 (0)