@@ -49,13 +49,32 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max
49
49
return nothing
50
50
end
51
51
52
"""
    checkbounds_src(src, dims::Int, ::Type{<:Any})

Return a predicate `i -> Bool` that reports whether index `i` is in bounds for
`src` when the first `dims` dimensions are indexed with `Colon()`.

The index `i` is splatted into the trailing index positions, so it may be a
single integer or a tuple of integers.
"""
function checkbounds_src(src, dims::Int, ::Type{<:Any})
    # The leading colons do not depend on the index being tested, so build them
    # once here instead of on every call of the returned predicate.
    colons = ntuple(_ -> Colon(), dims)
    return i -> checkbounds(Bool, src, colons..., i...)
end
55
+
56
"""
    checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})

Return a predicate `i -> Bool` that reports whether the `CartesianIndex` `i` is
in bounds for `src` when the first `dims` dimensions are indexed with `Colon()`.

The index is passed to `checkbounds` as a single argument (not splatted).
"""
function checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})
    # The leading colons do not depend on the index being tested, so build them
    # once here instead of on every call of the returned predicate.
    colons = ntuple(_ -> Colon(), dims)
    return i -> checkbounds(Bool, src, colons..., i)
end
59
+
52
60
function NNlib. gather! (dst:: AnyCuArray , src:: AnyCuArray , idx:: AnyCuArray )
61
+ # check dims
53
62
dims = gather_check_dims (src, dst, idx)
54
63
dims_size = size (src)[1 : dims]
55
64
max_dims_idx = prod (dims_size)
56
65
max_idx = max_dims_idx * length (idx)
57
- args = dst, src, idx, max_idx, max_dims_idx, dims_size
58
66
67
+ # check bounds
68
+ chk = checkbounds_src (src, dims, eltype (idx))
69
+ in_bnd = map (chk, collect (idx))
70
+ if ! all (in_bnd)
71
+ j = findfirst (i -> ! i, in_bnd)
72
+ k = CUDA. @allowscalar idx[j]
73
+ throw (BoundsError (src, k))
74
+ end
75
+
76
+ # cuda kernel
77
+ args = dst, src, idx, max_idx, max_dims_idx, dims_size
59
78
kernel = @cuda launch= false gather_kernel! (args... )
60
79
config = launch_configuration (kernel. fun; max_threads= 256 )
61
80
threads = min (max_idx, config. threads)
0 commit comments