Skip to content

Commit e418a4d

Browse files
Merge pull request #33 from yuehhua/fix
Support gpu scatter/gather with CartesianIndex
2 parents 210c1ea + a6a6155 commit e418a4d

File tree

4 files changed

+110
-0
lines changed

4 files changed

+110
-0
lines changed

ext/NNlibCUDA/src/gather.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@ function gather_kernel!(dst, src, idx, max_idx, max_dims_idx, dims_size)
3737
return nothing
3838
end
3939

40+
function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size)
41+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
42+
43+
@inbounds if index <= max_idx
44+
j, k = divrem(index-1, max_dims_idx)
45+
dims_i = CartesianIndices(dims_size)[k+1]
46+
li = Base._to_linear_index(src, Tuple(dims_i)..., Tuple(idx[j+1])...)
47+
dst[index] = src[li]
48+
end
49+
return nothing
50+
end
51+
4052
function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
4153
dims = gather_check_dims(src, dst, idx)
4254
dims_size = size(src)[1:dims]

ext/NNlibCUDA/src/scatter.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ function scatter_kernel!(op, dst, src, idx)
99
return nothing
1010
end
1111

12+
function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex})
13+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
14+
15+
@inbounds if index <= length(idx)
16+
li = Base._to_linear_index(dst, Tuple(idx[index])...)
17+
CUDA.@atomic dst[li] = op(dst[li], src[index])
18+
end
19+
return nothing
20+
end
21+
1222
function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
1323
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
1424

@@ -20,6 +30,18 @@ function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
2030
return nothing
2131
end
2232

33+
function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size)
34+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
35+
36+
@inbounds if index <= max_idx
37+
j, k = divrem(index-1, max_dims_idx)
38+
dims_i = CartesianIndices(dims_size)[k+1]
39+
li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...)
40+
CUDA.@atomic dst[li] = op(dst[li], src[index])
41+
end
42+
return nothing
43+
end
44+
2345
function NNlib.scatter!(op, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
2446
dims = NNlib.scatter_dims(dst, src, idx)
2547
args = if dims == 0
@@ -69,6 +91,25 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
6991
return nothing
7092
end
7193

94+
function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, max_idx, T)
95+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
96+
97+
@inbounds if index <= max_idx
98+
cart_j = CartesianIndices(idx)[index]
99+
# get aggregating indeices, which is to be aggregated together, and itself index
100+
inds = rev_idx[Tuple(idx[cart_j])...]
101+
# multiply all values to be aggregated but not itself
102+
x = one(T)
103+
for k in inds
104+
x *= src[k]
105+
end
106+
x /= src[cart_j]
107+
# apply `op` on `Δsrc[i, k]` and `x`
108+
Δsrc[cart_j] = op(Δsrc[cart_j], x)
109+
end
110+
return nothing
111+
end
112+
72113
function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
73114
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
74115

@@ -91,6 +132,28 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_
91132
return nothing
92133
end
93134

135+
function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
136+
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
137+
138+
@inbounds if index <= max_idx
139+
i, j = fldmod1(index, max_dims_idx)
140+
cart_i = CartesianIndices(idx)[i]
141+
cart_j = pre_cart_idx[j]
142+
# get aggregating indeices, which is to be aggregated together, and itself index
143+
inds = rev_idx[Tuple(idx[cart_i])...]
144+
# multiply all values to be aggregated but not itself
145+
x = one(T)
146+
for k in inds
147+
jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
148+
x *= src[jk]
149+
end
150+
x /= src[index]
151+
# apply `op` on `Δsrc[i, k]` and `x`
152+
Δsrc[index] = op(Δsrc[index], x)
153+
end
154+
return nothing
155+
end
156+
94157
function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
95158
src::AnyCuArray{Tsrc,Nsrc},
96159
idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}

ext/NNlibCUDA/test/gather.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,38 @@
1717
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
1818
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
1919
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
20+
21+
## 1d src, 2d index of tuples -> 2d output
22+
src = CT([3, 4, 5, 6, 7])
23+
index = cu([(1,) (2,) (3,) (4,);
24+
(4,) (2,) (1,) (3,);
25+
(3,) (5,) (5,) (3,)])
26+
output = CT([3 4 5 6;
27+
6 4 3 5;
28+
5 7 7 5])
29+
30+
y = NNlib.gather(src, index)
31+
@test y isa CuArray{Float32,2}
32+
@test size(y) == size(index)
33+
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
34+
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
35+
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
36+
37+
## 1d src, 2d index of CartesianIndex -> 2d output
38+
src = CT([3, 4, 5, 6, 7])
39+
index = cu(CartesianIndex.([(1,) (2,) (3,) (4,);
40+
(4,) (2,) (1,) (3,);
41+
(3,) (5,) (5,) (3,)]))
42+
output = CT([3 4 5 6;
43+
6 4 3 5;
44+
5 7 7 5])
45+
46+
y = NNlib.gather(src, index)
47+
@test y isa CuArray{Float32,2}
48+
@test size(y) == size(index)
49+
gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
50+
@test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
51+
@test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
2052

2153
## 1d src, 3d index of ints -> 3d output
2254
src = CT([3, 4, 5, 6, 7])

ext/NNlibCUDA/test/scatter.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ idxs = [
1616
cu([(1,) (2,) (3,) (4,);
1717
(4,) (2,) (1,) (3,);
1818
(3,) (5,) (5,) (3,)]), # tuple index
19+
cu(CartesianIndex.([(1,) (2,) (3,) (4,);
20+
(4,) (2,) (1,) (3,);
21+
(3,) (5,) (5,) (3,)])), # CartesianIndex index
1922
]
2023

2124
types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]

0 commit comments

Comments
 (0)