Skip to content

Commit 2595b92

Browse files
Merge pull request #43 from FluxML/cl/scatter
specialize on op in scatter
2 parents 32919ac + 48fbdc2 commit 2595b92

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

ext/NNlibCUDA/src/scatter.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# supported op: +, -, *, /, max, min, &, |, mean
22

3-
function scatter_kernel!(op, dst, src, idx)
3+
function scatter_kernel!(op::OP, dst, src, idx) where OP
44
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
55

66
@inbounds if index <= length(idx)
@@ -9,7 +9,7 @@ function scatter_kernel!(op, dst, src, idx)
99
return nothing
1010
end
1111

12-
function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex})
12+
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}) where OP
1313
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
1414

1515
@inbounds if index <= length(idx)
@@ -19,7 +19,7 @@ function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}
1919
return nothing
2020
end
2121

22-
function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
22+
function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size) where OP
2323
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
2424

2525
@inbounds if index <= max_idx
@@ -30,7 +30,8 @@ function scatter_kernel!(op, dst, src, idx, max_idx, max_dims_idx, dims_size)
3030
return nothing
3131
end
3232

33-
function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max_idx, max_dims_idx, dims_size)
33+
function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
34+
max_idx, max_dims_idx, dims_size) where OP
3435
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
3536

3637
@inbounds if index <= max_idx
@@ -42,7 +43,7 @@ function scatter_kernel!(op, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}
4243
return nothing
4344
end
4445

45-
function NNlib.scatter!(op, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
46+
function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP
4647
dims = NNlib.scatter_dims(dst, src, idx)
4748
args = if dims == 0
4849
max_idx = length(idx)
@@ -72,7 +73,8 @@ end
7273

7374
## Gradients
7475

75-
function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
76+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
77+
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
7678
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
7779

7880
@inbounds if index <= max_idx
@@ -91,7 +93,8 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, max_idx, T)
9193
return nothing
9294
end
9395

94-
function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, max_idx, T)
96+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
97+
rev_idx, max_idx, T::Type{TT}) where {OP,TT}
9598
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
9699

97100
@inbounds if index <= max_idx
@@ -110,7 +113,8 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:Cartes
110113
return nothing
111114
end
112115

113-
function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
116+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx,
117+
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
114118
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
115119

116120
@inbounds if index <= max_idx
@@ -132,7 +136,8 @@ function ∇scatter_src_kernel!(op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_
132136
return nothing
133137
end
134138

135-
function ∇scatter_src_kernel!(op, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, rev_idx, pre_cart_idx, max_dims_idx, max_idx, T)
139+
function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex},
140+
rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT}
136141
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
137142

138143
@inbounds if index <= max_idx

0 commit comments

Comments
 (0)