Skip to content

Commit 0e345cb

Browse files
Update src/gather.jl
Co-authored-by: Peter <[email protected]> Update src/gather.jl Co-authored-by: Peter <[email protected]> Update src/gather.jl Co-authored-by: Peter <[email protected]> Update src/gather.jl Co-authored-by: Peter <[email protected]>
1 parent 1f0ad39 commit 0e345cb

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

src/gather.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max
4949
return nothing
5050
end
5151

52-
function checkbounds_src(src, dims::Int, ::Type{<:Any})
52+
function checkbounds_src(src, dims::Union{Int, Val}, ::Type{<:Any})
5353
return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i...)
5454
end
5555

56-
function checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})
56+
function checkbounds_src(src, dims::Union{Int, Val}, ::Type{<:CartesianIndex})
5757
return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i)
5858
end
5959

@@ -65,10 +65,9 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
6565
max_idx = max_dims_idx * length(idx)
6666

6767
# check bounds
68-
chk = checkbounds_src(src, dims, eltype(idx))
69-
in_bnd = map(chk, collect(idx))
68+
in_bnd = map(checkbounds_src(src, Val(dims), eltype(idx)), idx)
7069
if !all(in_bnd)
71-
j = findfirst(i -> !i, in_bnd)
70+
j = findfirst(!, in_bnd)
7271
k = CUDA.@allowscalar idx[j]
7372
throw(BoundsError(src, k))
7473
end

0 commit comments

Comments
 (0)