From f1bd6e248b20f2ad702f3a86e525869f0df3e9fb Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Fri, 3 Jun 2022 22:53:16 +0800 Subject: [PATCH 1/5] add checkbounds for gather --- src/gather.jl | 21 ++++++++++++++++++++- test/gather.jl | 2 ++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/gather.jl b/src/gather.jl index 5690dfc..ed6cc20 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -49,13 +49,32 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max return nothing end +function checkbounds_src(src, dims::Int, ::Type{<:Any}) + return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i...) +end + +function checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex}) + return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i) +end + function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) + # check dims dims = gather_check_dims(src, dst, idx) dims_size = size(src)[1:dims] max_dims_idx = prod(dims_size) max_idx = max_dims_idx * length(idx) - args = dst, src, idx, max_idx, max_dims_idx, dims_size + # check bounds + chk = checkbounds_src(src, dims, eltype(idx)) + in_bnd = map(chk, collect(idx)) + if !all(in_bnd) + j = findfirst(i -> !i, in_bnd) + k = CUDA.@allowscalar idx[j] + throw(BoundsError(src, k)) + end + + # cuda kernel + args = dst, src, idx, max_idx, max_dims_idx, dims_size kernel = @cuda launch=false gather_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) diff --git a/test/gather.jl b/test/gather.jl index f200f77..5bc0213 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -17,6 +17,8 @@ gputest(src -> NNlib.gather(src, index), src, checkgrad=true) @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index) + index[1,:] .= 6 + @test_throws BoundsError NNlib.gather(src, index) ## 1d src, 2d index of tuples -> 2d output src = CT([3, 4, 5, 6, 7]) From e58131a23c5f4c04ca254245c5f1e4455dbf3491 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sat, 4 Jun 2022 00:28:41 +0800 Subject: [PATCH 2/5] support empty source array for gather fix --- src/gather.jl | 1 + test/gather.jl | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/src/gather.jl b/src/gather.jl index ed6cc20..5ad39b6 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -66,6 +66,7 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) # check bounds chk = checkbounds_src(src, dims, eltype(idx)) + isempty(src) && return dst in_bnd = map(chk, collect(idx)) if !all(in_bnd) j = findfirst(i -> !i, in_bnd) diff --git a/test/gather.jl b/test/gather.jl index 5bc0213..4252108 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -91,4 +91,14 @@ @test y isa CuArray{Float32,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) gputest(src -> NNlib.gather(src, index), src, checkgrad=true) + + ## empty 2d src, 2d index of ints -> 3d output + src = CT(zeros(Int, 0, 3)) + index = cu([1 2 3; + 2 2 1; + 3 1 3]) + + y = NNlib.gather(src, index) + @test y isa CuArray{Float32,3} + @test size(y) == (0, size(index)...) end From 1f0ad39eb4c238edf3ab3a151bc5b3380af93cf0 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sat, 4 Jun 2022 13:58:52 +0800 Subject: [PATCH 3/5] fix --- src/gather.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gather.jl b/src/gather.jl index 5ad39b6..7fb0069 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -66,7 +66,6 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) # check bounds chk = checkbounds_src(src, dims, eltype(idx)) - isempty(src) && return dst in_bnd = map(chk, collect(idx)) if !all(in_bnd) j = findfirst(i -> !i, in_bnd) @@ -74,6 +73,9 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) throw(BoundsError(src, k)) end + # empty array input + isempty(src) && return dst + # cuda kernel args = dst, src, idx, max_idx, max_dims_idx, dims_size kernel = @cuda launch=false gather_kernel!(args...) From 0e345cb33f1fcf866265e69136b78d14fc342d0d Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sat, 4 Jun 2022 14:02:42 +0800 Subject: [PATCH 4/5] Update src/gather.jl Co-authored-by: Peter Update src/gather.jl Co-authored-by: Peter Update src/gather.jl Co-authored-by: Peter Update src/gather.jl Co-authored-by: Peter --- src/gather.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index 7fb0069..b506388 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -49,11 +49,11 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max return nothing end -function checkbounds_src(src, dims::Int, ::Type{<:Any}) +function checkbounds_src(src, dims::Union{Int, Val}, ::Type{<:Any}) return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i...) end -function checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex}) +function checkbounds_src(src, dims::Union{Int, Val}, ::Type{<:CartesianIndex}) return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i) end @@ -65,10 +65,9 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) max_idx = max_dims_idx * length(idx) # check bounds - chk = checkbounds_src(src, dims, eltype(idx)) - in_bnd = map(chk, collect(idx)) + in_bnd = map(checkbounds_src(src, Val(dims), eltype(idx)), idx) if !all(in_bnd) - j = findfirst(i -> !i, in_bnd) + j = findfirst(!, in_bnd) k = CUDA.@allowscalar idx[j] throw(BoundsError(src, k)) end From d3e4bec2f8f9dfee406640f271b54313c3f0bb34 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Thu, 16 Jun 2022 10:44:08 +0800 Subject: [PATCH 5/5] resolve closure --- Project.toml | 1 + src/NNlibCUDA.jl | 2 +- src/gather.jl | 18 ++++++++++-------- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index cafb011..4f0ad30 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" version = "0.2.3" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/src/NNlibCUDA.jl b/src/NNlibCUDA.jl index 3b91d42..f02da79 100644 --- a/src/NNlibCUDA.jl +++ b/src/NNlibCUDA.jl @@ -1,7 +1,7 @@ module NNlibCUDA using NNlib -using CUDA +using CUDA, Adapt using Random, Statistics const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N} diff --git a/src/gather.jl b/src/gather.jl index b506388..e79e1b4 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -49,13 +49,15 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max return nothing end -function checkbounds_src(src, dims::Union{Int, Val}, ::Type{<:Any}) - return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i...) +struct BoundInfo{T,D} + A::T + dims::D end -function checkbounds_src(src, dims::Union{Int, Val}, ::Type{<:CartesianIndex}) - return i -> checkbounds(Bool, src, ntuple(x -> Colon(), dims)..., i) -end +Adapt.adapt_structure(to, x::BoundInfo) = BoundInfo(adapt(to, parent(x.A)), x.dims) + +(b::BoundInfo)(i) = checkbounds(Bool, b.A, ntuple(x -> Colon(), b.dims)..., i...) +(b::BoundInfo)(i::CartesianIndex) = checkbounds(Bool, b.A, ntuple(x -> Colon(), b.dims)..., i) function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) # check dims @@ -65,9 +67,9 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) max_idx = max_dims_idx * length(idx) # check bounds - in_bnd = map(checkbounds_src(src, Val(dims), eltype(idx)), idx) - if !all(in_bnd) - j = findfirst(!, in_bnd) + in_bnd = mapreduce(BoundInfo(src, Val(dims)), &, idx) + if !in_bnd + j = findfirst(!, map(BoundInfo(src, Val(dims)), idx)) k = CUDA.@allowscalar idx[j] throw(BoundsError(src, k)) end