Skip to content

Commit ad9fda2

Browse files
Update src/gather.jl
Co-authored-by: Peter <[email protected]>
1 parent f927bf8 commit ad9fda2

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/gather.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
6565
max_idx = max_dims_idx * length(idx)
6666

6767
# check bounds
68-
chk = checkbounds_src(src, dims, eltype(idx))
69-
in_bnd = map(chk, collect(idx))
68+
in_bnd = map(checkbounds_src(src, Val(dims), eltype(idx)), idx)
7069
if !all(in_bnd)
7170
j = findfirst(!, in_bnd)
7271
k = CUDA.@allowscalar idx[j]

0 commit comments

Comments
 (0)