Skip to content

Commit f1bd6e2

Browse files
committed
add checkbounds for gather
1 parent da73f07 commit f1bd6e2

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/gather.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,32 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max
4949
return nothing
5050
end
5151

52+
"""
    checkbounds_src(src, dims::Int, ::Type{<:Any})

Build a one-argument predicate that tells whether an index is in bounds for `src`.

The returned function takes an index `i`, prefixes it with `dims` colons, splats
`i` into the trailing positions, and returns the result of
`checkbounds(Bool, src, :, ..., i...)`. This is the fallback method, used when the
index element type is not matched by a more specific method.
"""
function checkbounds_src(src, dims::Int, ::Type{<:Any})
    function inbounds_at(index)
        # One Colon per leading dimension; only the trailing positions are tested
        # against the concrete index values.
        leading = ntuple(_ -> Colon(), dims)
        return checkbounds(Bool, src, leading..., index...)
    end
    return inbounds_at
end
55+
56+
"""
    checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})

Build a one-argument predicate that tells whether a `CartesianIndex` is in bounds
for `src`.

The returned function takes an index `i`, prefixes it with `dims` colons, and
returns the result of `checkbounds(Bool, src, :, ..., i)`. The index is passed as a
single argument rather than splatted.
"""
function checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})
    return function (index)
        # One Colon per leading dimension; the Cartesian index covers the rest.
        colons = ntuple(_ -> Colon(), dims)
        return checkbounds(Bool, src, colons..., index)
    end
end
59+
5260
function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
61+
# check dims
5362
dims = gather_check_dims(src, dst, idx)
5463
dims_size = size(src)[1:dims]
5564
max_dims_idx = prod(dims_size)
5665
max_idx = max_dims_idx * length(idx)
57-
args = dst, src, idx, max_idx, max_dims_idx, dims_size
5866

67+
# check bounds
68+
chk = checkbounds_src(src, dims, eltype(idx))
69+
in_bnd = map(chk, collect(idx))
70+
if !all(in_bnd)
71+
j = findfirst(i -> !i, in_bnd)
72+
k = CUDA.@allowscalar idx[j]
73+
throw(BoundsError(src, k))
74+
end
75+
76+
# cuda kernel
77+
args = dst, src, idx, max_idx, max_dims_idx, dims_size
5978
kernel = @cuda launch=false gather_kernel!(args...)
6079
config = launch_configuration(kernel.fun; max_threads=256)
6180
threads = min(max_idx, config.threads)

test/gather.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
1818
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
1919
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
20+
index[1,:] .= 6
21+
@test_throws BoundsError NNlib.gather(src, index)
2022

2123
## 1d src, 2d index of tuples -> 2d output
2224
src = CT([3, 4, 5, 6, 7])

0 commit comments

Comments
 (0)