1
1
# supported op: +, -, *, /, max, min, &, |, mean
2
2
3
- function scatter_kernel! (op, dst, src, idx)
3
+ function scatter_kernel! (op:: OP , dst, src, idx) where OP
4
4
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
5
5
6
6
@inbounds if index <= length (idx)
@@ -9,7 +9,7 @@ function scatter_kernel!(op, dst, src, idx)
9
9
return nothing
10
10
end
11
11
12
- function scatter_kernel! (op, dst, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} )
12
+ function scatter_kernel! (op:: OP , dst, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} ) where OP
13
13
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
14
14
15
15
@inbounds if index <= length (idx)
@@ -19,7 +19,7 @@ function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}
19
19
return nothing
20
20
end
21
21
22
- function scatter_kernel! (op, dst, src, idx, max_idx, max_dims_idx, dims_size)
22
+ function scatter_kernel! (op:: OP , dst, src, idx, max_idx, max_dims_idx, dims_size) where OP
23
23
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
24
24
25
25
@inbounds if index <= max_idx
@@ -30,7 +30,8 @@ function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
30
30
return nothing
31
31
end
32
32
33
- function scatter_kernel! (op, dst, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} , max_idx, max_dims_idx, dims_size)
33
+ function scatter_kernel! (op:: OP , dst, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} ,
34
+ max_idx, max_dims_idx, dims_size) where OP
34
35
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
35
36
36
37
@inbounds if index <= max_idx
@@ -42,7 +43,7 @@ function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}
42
43
return nothing
43
44
end
44
45
45
- function NNlib. scatter! (op, dst:: AnyCuArray , src:: AnyCuArray , idx:: AnyCuArray )
46
+ function NNlib. scatter! (op:: OP , dst:: AnyCuArray , src:: AnyCuArray , idx:: AnyCuArray ) where OP
46
47
dims = NNlib. scatter_dims (dst, src, idx)
47
48
args = if dims == 0
48
49
max_idx = length (idx)
72
73
73
74
# # Gradients
74
75
75
- function ∇scatter_src_kernel! (op, Δsrc, src, idx, rev_idx, max_idx, T)
76
+ function ∇scatter_src_kernel! (op:: OP , Δsrc, src, idx,
77
+ rev_idx, max_idx, T:: Type{TT} ) where {OP,TT}
76
78
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
77
79
78
80
@inbounds if index <= max_idx
@@ -91,7 +93,8 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
91
93
return nothing
92
94
end
93
95
94
- function ∇scatter_src_kernel! (op, Δsrc, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} , rev_idx, max_idx, T)
96
+ function ∇scatter_src_kernel! (op:: OP , Δsrc, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} ,
97
+ rev_idx, max_idx, T:: Type{TT} ) where {OP,TT}
95
98
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
96
99
97
100
@inbounds if index <= max_idx
@@ -110,7 +113,8 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:Cartes
110
113
return nothing
111
114
end
112
115
113
- function ∇scatter_src_kernel! (op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
116
+ function ∇scatter_src_kernel! (op:: OP , Δsrc, src, idx,
117
+ rev_idx, pre_cart_idx, max_dims_idx, max_idx, T:: Type{TT} ) where {OP,TT}
114
118
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
115
119
116
120
@inbounds if index <= max_idx
@@ -132,7 +136,8 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_
132
136
return nothing
133
137
end
134
138
135
- function ∇scatter_src_kernel! (op, Δsrc, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} , rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
139
+ function ∇scatter_src_kernel! (op:: OP , Δsrc, src, idx:: CUDA.CuDeviceArray{<:CartesianIndex} ,
140
+ rev_idx, pre_cart_idx, max_dims_idx, max_idx, T:: Type{TT} ) where {OP,TT}
136
141
index = threadIdx (). x + (blockIdx (). x - 1 ) * blockDim (). x
137
142
138
143
@inbounds if index <= max_idx
0 commit comments