@@ -9,6 +9,16 @@ function scatter_kernel!(op, dst, src, idx)
9
9
return nothing
10
10
end
11
11
12
# CUDA kernel: scatter `src` into `dst` at Cartesian destinations.
#
# One thread handles one element of `idx`: it folds `src[i]` into the
# destination slot addressed by the Cartesian index `idx[i]` via `op`.
function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex})
    i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if i <= length(idx)
        # Convert the Cartesian destination index to a linear index so the
        # atomic read-modify-write targets a single memory location.
        lin = Base._to_linear_index(dst, Tuple(idx[i])...)
        # Atomic update: several threads may scatter into the same slot.
        CUDA.@atomic dst[lin] = op(dst[lin], src[i])
    end
    return nothing
end
12
22
function scatter_kernel! (op, dst, src, idx, max_idx, max_dims_idx, dims_size)
13
23
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
14
24
@@ -20,6 +30,18 @@ function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
20
30
return nothing
21
31
end
22
32
33
# CUDA kernel: scatter `src` into `dst` at Cartesian destinations, where each
# `idx` entry addresses a trailing slice of `dst` of shape `dims_size`.
#
# One thread handles one element of `src` (there are `max_idx` of them,
# `max_dims_idx == prod(dims_size)` per `idx` entry).
function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        # Decompose the flat thread index into the `idx` entry (j) and the
        # position within the leading `dims_size` block (k). `fldmod1` yields
        # 1-based components directly, matching ∇scatter_src_kernel!
        # (previously: `divrem(index - 1, max_dims_idx)` plus manual `+1`s).
        j, k = fldmod1(index, max_dims_idx)
        dims_i = CartesianIndices(dims_size)[k]
        # Linearize (leading dims..., scattered dims...) for the atomic update.
        li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j])...)
        # Atomic: multiple threads may target the same destination slot.
        CUDA.@atomic dst[li] = op(dst[li], src[index])
    end
    return nothing
end
23
45
function NNlib. scatter! (op, dst:: AnyCuArray , src:: AnyCuArray , idx:: AnyCuArray )
24
46
dims = NNlib. scatter_dims (dst, src, idx)
25
47
args = if dims == 0
@@ -69,6 +91,25 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
69
91
return nothing
70
92
end
71
93
94
# CUDA kernel: gradient of scatter w.r.t. `src` for multiplicative `op`.
#
# For each source position, multiply together every sibling value that
# scatters into the same destination, excluding the element itself, and
# combine that factor into `Δsrc` via `op`.
function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, max_idx, T)
    thread = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if thread <= max_idx
        pos = CartesianIndices(idx)[thread]
        # All source positions aggregated into the same destination as `pos`
        # (the reverse index includes `pos` itself).
        siblings = rev_idx[Tuple(idx[pos])...]
        # Product over all siblings, then divide out this element's own value
        # so the factor excludes itself.
        acc = one(T)
        for s in siblings
            acc *= src[s]
        end
        acc /= src[pos]
        # Fold the factor into the accumulated gradient with `op`.
        Δsrc[pos] = op(Δsrc[pos], acc)
    end
    return nothing
end
72
113
function ∇scatter_src_kernel! (op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
73
114
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
74
115
@@ -91,6 +132,28 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_
91
132
return nothing
92
133
end
93
134
135
# CUDA kernel: gradient of scatter w.r.t. `src` for multiplicative `op`,
# dims variant — each `idx` entry addresses a slice over the leading
# dimensions enumerated by `pre_cart_idx` (`max_dims_idx` positions each).
function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        # Split the flat thread index into the `idx` entry (i) and the
        # position within the leading dimensions (j), both 1-based.
        i, j = fldmod1(index, max_dims_idx)
        cart_i = CartesianIndices(idx)[i]
        cart_j = pre_cart_idx[j]
        # Aggregating indices: all `idx` positions scattered into the same
        # destination, including this element's own index.
        inds = rev_idx[Tuple(idx[cart_i])...]
        # Multiply all values aggregated into the same destination (same
        # leading position `cart_j`), then divide out this element's own
        # value so the product excludes itself.
        x = one(T)
        for k in inds
            jk = Base._to_linear_index(src, Tuple(cart_j)..., Tuple(k)...)
            x *= src[jk]
        end
        x /= src[index]
        # Apply `op` to the existing gradient entry and the computed factor.
        Δsrc[index] = op(Δsrc[index], x)
    end
    return nothing
end
94
157
function NNlib. ∇scatter_src (op:: Union{typeof(*),typeof(/)} , Δ, dst,
95
158
src:: AnyCuArray{Tsrc,Nsrc} ,
96
159
idx:: AnyCuArray{Tidx,Nidx} ) where {Tsrc,Tidx,Nsrc,Nidx}
0 commit comments