@@ -70,7 +70,7 @@ julia> NNlib.scatter!(*, fill(0.5, 2, 4), [1 10; 100 1000], [3,2])
70
70
0.5 500.0 50.0 0.5
71
71
```
72
72
"""
73
- function scatter! (op, dst:: AbstractArray , src:: AbstractArray , idx:: AbstractArray )
73
+ function scatter! (op:: OP , dst:: AbstractArray , src:: AbstractArray , idx:: AbstractArray ) where OP
74
74
dims = scatter_dims (dst, src, idx)
75
75
colons = Base. ntuple (_-> Colon (), dims)
76
76
for k in CartesianIndices (idx)
@@ -127,11 +127,11 @@ julia> NNlib.scatter(*, [10,200,3000], [1,4,2]; init = 10, dstsize = 6)
127
127
10
128
128
```
129
129
"""
130
- function scatter (op,
130
+ function scatter (op:: OP ,
131
131
src:: AbstractArray{Tsrc,Nsrc} ,
132
132
idx:: AbstractArray{Tidx,Nidx} ;
133
- init = nothing , dstsize = nothing ) where {Tsrc,Tidx,Nsrc,Nidx}
134
-
133
+ init = nothing , dstsize = nothing ) where {Tsrc,Tidx,Nsrc,Nidx,OP }
134
+
135
135
dims = Nsrc - Nidx
136
136
dstsz = isnothing (dstsize) ? (size (src)[1 : dims]. .. , maximum_dims (idx)... ) : dstsize
137
137
dst = similar (src, Tsrc, dstsz)
0 commit comments