Skip to content

Commit d7927f3

Browse files
trigger op specialization in scatter (#384)
* trigger op specialization in scatter
1 parent 0400358 commit d7927f3

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/scatter.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2])
7070
0.5 500.0 50.0 0.5
7171
```
7272
"""
73-
function scatter!(op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
73+
function scatter!(op::OP, dst::AbstractArray, src::AbstractArray, idx::AbstractArray) where OP
7474
dims = scatter_dims(dst, src, idx)
7575
colons = Base.ntuple(_->Colon(), dims)
7676
for k in CartesianIndices(idx)
@@ -127,11 +127,11 @@ julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6)
127127
10
128128
```
129129
"""
130-
function scatter(op,
130+
function scatter(op::OP,
131131
src::AbstractArray{Tsrc,Nsrc},
132132
idx::AbstractArray{Tidx,Nidx};
133-
init = nothing, dstsize = nothing) where {Tsrc,Tidx,Nsrc,Nidx}
134-
133+
init = nothing, dstsize = nothing) where {Tsrc,Tidx,Nsrc,Nidx,OP}
134+
135135
dims = Nsrc - Nidx
136136
dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize
137137
dst = similar(src, Tsrc, dstsz)

0 commit comments

Comments
 (0)