diff --git a/Project.toml b/Project.toml
index cafb011..4f0ad30 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 version = "0.2.3"
 
 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
diff --git a/src/NNlibCUDA.jl b/src/NNlibCUDA.jl
index 3b91d42..f02da79 100644
--- a/src/NNlibCUDA.jl
+++ b/src/NNlibCUDA.jl
@@ -1,7 +1,7 @@
 module NNlibCUDA
 
 using NNlib
-using CUDA
+using CUDA, Adapt
 using Random, Statistics
 
 const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
diff --git a/src/gather.jl b/src/gather.jl
index 5690dfc..e79e1b4 100644
--- a/src/gather.jl
+++ b/src/gather.jl
@@ -49,13 +49,43 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max
     return nothing
 end
 
+# Bounds-checking functor: `b(i)` is `true` when index `i` addresses a valid
+# slice of `b.A` once the first `b.dims` (a `Val`) dimensions are spanned by
+# colons. Wrapping the array in a struct lets Adapt convert it for device code.
+struct BoundInfo{T,D}
+    A::T
+    dims::D
+end
+
+# Adapt the array itself, not `parent(x.A)`: for wrapped arrays (views,
+# reshapes, adjoints) the parent has different axes than the array being indexed.
+Adapt.adapt_structure(to, x::BoundInfo) = BoundInfo(adapt(to, x.A), x.dims)
+
+(b::BoundInfo)(i) = checkbounds(Bool, b.A, ntuple(x -> Colon(), b.dims)..., i...)
+(b::BoundInfo)(i::CartesianIndex) = checkbounds(Bool, b.A, ntuple(x -> Colon(), b.dims)..., i)
+
 function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
+    # check dims
     dims = gather_check_dims(src, dst, idx)
     dims_size = size(src)[1:dims]
     max_dims_idx = prod(dims_size)
     max_idx = max_dims_idx * length(idx)
-    args = dst, src, idx, max_idx, max_dims_idx, dims_size
 
+    # check bounds
+    in_bounds = BoundInfo(src, Val(dims))
+    if !mapreduce(in_bounds, &, idx)
+        # failure path only: locate the first offending index for the error
+        j = findfirst(!, map(in_bounds, idx))
+        k = CUDA.@allowscalar idx[j]
+        throw(BoundsError(src, k))
+    end
+
+    # nothing to copy (empty leading dims of `src`, or empty `idx`): return
+    # before the launch setup, where `threads` would be `min(0, ...) == 0`
+    iszero(max_idx) && return dst
+
+    # cuda kernel
+    args = dst, src, idx, max_idx, max_dims_idx, dims_size
     kernel = @cuda launch=false gather_kernel!(args...)
     config = launch_configuration(kernel.fun; max_threads=256)
     threads = min(max_idx, config.threads)
diff --git a/test/gather.jl b/test/gather.jl
index f200f77..4252108 100644
--- a/test/gather.jl
+++ b/test/gather.jl
@@ -17,6 +17,8 @@
     gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
     @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
     @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
+    index[1,:] .= 6
+    @test_throws BoundsError NNlib.gather(src, index)
 
     ## 1d src, 2d index of tuples -> 2d output
     src = CT([3, 4, 5, 6, 7])
@@ -89,4 +91,14 @@
     @test y isa CuArray{Float32,3}
     @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
     gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
+
+    ## empty 2d src, 2d index of ints -> 3d output
+    src = CT(zeros(Int, 0, 3))
+    index = cu([1 2 3;
+                2 2 1;
+                3 1 3])
+
+    y = NNlib.gather(src, index)
+    @test y isa CuArray{Float32,3}
+    @test size(y) == (0, size(index)...)
 end