Skip to content

Commit c81deb7

Browse files
committed
support empty source array for gather
1 parent f1bd6e2 commit c81deb7

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/gather.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
6666

6767
# check bounds
6868
chk = checkbounds_src(src, dims, eltype(idx))
69+
isempty(src) && return dst
6970
in_bnd = map(chk, collect(idx))
7071
if !all(in_bnd)
7172
j = findfirst(i -> !i, in_bnd)

test/gather.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,14 @@
9191
@test y isa CuArray{Float32,3}
9292
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
9393
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
94+
95+
## empty 2d src, 2d index of ints -> 3d output
96+
src = CT(zeros(Int, 0, 3))
97+
index = cu([1 2 3;
98+
2 2 1;
99+
3 1 3])
100+
101+
y = NNlib.gather(src, index)
102+
@test y isa CuArray{Float32,2}
103+
@test size(y) == (0, 3, 3)
94104
end

0 commit comments

Comments
 (0)