@@ -49,11 +49,11 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max
49
49
return nothing
50
50
end
51
51
52
- function checkbounds_src (src, dims:: Int , :: Type{<:Any} )
52
+ function checkbounds_src (src, dims:: Union{ Int, Val} , :: Type{<:Any} )
53
53
return i -> checkbounds (Bool, src, ntuple (x -> Colon (), dims)... , i... )
54
54
end
55
55
56
- function checkbounds_src (src, dims:: Int , :: Type{<:CartesianIndex} )
56
+ function checkbounds_src (src, dims:: Union{ Int, Val} , :: Type{<:CartesianIndex} )
57
57
return i -> checkbounds (Bool, src, ntuple (x -> Colon (), dims)... , i)
58
58
end
59
59
@@ -65,10 +65,9 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
65
65
max_idx = max_dims_idx * length (idx)
66
66
67
67
# check bounds
68
- chk = checkbounds_src (src, dims, eltype (idx))
69
- in_bnd = map (chk, collect (idx))
68
+ in_bnd = map (checkbounds_src (src, Val (dims), eltype (idx)), idx)
70
69
if ! all (in_bnd)
71
- j = findfirst (i -> ! i , in_bnd)
70
+ j = findfirst (! , in_bnd)
72
71
k = CUDA. @allowscalar idx[j]
73
72
throw (BoundsError (src, k))
74
73
end
0 commit comments