Skip to content

Commit 4be74ad

Browse files
committed
gather throw error on cpu and gpu array inputs
1 parent da73f07 commit 4be74ad

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/gather.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,17 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
6363
kernel(args...; threads=threads, blocks=blocks)
6464
return dst
6565
end
66+
67+
function NNlib.gather(src::AnyCuArray{Tsrc, Nsrc},
68+
idx::AnyCuArray{Tidx, Nidx}) where
69+
{Tsrc, Nsrc, Nidx, Tidx}
70+
M = NNlib.typelength(Tidx)
71+
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
72+
dst = similar(src, Tsrc, dstsize)
73+
return NNlib.gather!(dst, src, idx)
74+
end
75+
76+
function NNlib.gather(src::AnyCuArray, idx::AbstractArray)
77+
err_msg = "src and idx both must be on GPU, but received $(typeof(src)) and $(typeof(idx)), respectively."
78+
throw(ArgumentError(err_msg))
79+
end

test/gather.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
1818
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
1919
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
20+
@test_throws ArgumentError NNlib.gather(src, collect(index))
2021

2122
## 1d src, 2d index of tuples -> 2d output
2223
src = CT([3, 4, 5, 6, 7])

0 commit comments

Comments
 (0)