27
27
NaiveAggregator (f) = (im, batches) -> NaiveAggregator (im, batches, f)
28
28
29
29
function aggregate! (a:: NaiveAggregator , aggbuf, data)
30
- fill! (aggbuf, zero ( eltype ( aggbuf) ))
30
+ fill! (aggbuf, _appropriate_zero ( aggbuf))
31
31
_aggregate! (a, a. batches, aggbuf, data)
32
32
end
33
33
function _aggregate! (a:: NaiveAggregator , batches, aggbuf, data)
@@ -101,7 +101,7 @@ KAAggregator(im, batches, f) = KAAggregator(f, AggregationMap(im, batches))
101
101
102
102
function aggregate! (a:: KAAggregator , aggbuf, data)
103
103
am = a. m
104
- fill! (aggbuf, zero ( eltype ( aggbuf) ))
104
+ fill! (aggbuf, _appropriate_zero ( aggbuf))
105
105
_backend = get_backend (data)
106
106
# kernel = agg_kernel!(_backend, 1024, length(am.map))
107
107
# kernel(a.f, aggbuf, view(data, am.range), am.map)
@@ -140,7 +140,7 @@ SequentialAggregator(f) = (im, batches) -> SequentialAggregator(im, batches, f)
140
140
SequentialAggregator (im, batches, f) = SequentialAggregator (f, AggregationMap (im, batches))
141
141
142
142
function aggregate! (a:: SequentialAggregator , aggbuf, data)
143
- fill! (aggbuf, zero ( eltype ( aggbuf) ))
143
+ fill! (aggbuf, _appropriate_zero ( aggbuf))
144
144
145
145
am = a. m
146
146
@inbounds begin
@@ -169,7 +169,7 @@ PolyesterAggregator(im, batches, f) = PolyesterAggregator(f, _inv_aggregation_ma
169
169
170
170
function aggregate! (a:: PolyesterAggregator , aggbuf, data)
171
171
length (a. m) == length (aggbuf) || throw (DimensionMismatch (" length of aggbuf and a.m must be equal" ))
172
- fill! (aggbuf, zero ( eltype ( aggbuf) ))
172
+ fill! (aggbuf, _appropriate_zero ( aggbuf))
173
173
174
174
maxdepth = mapreduce (x -> length (x[2 ]), max, a. m)
175
175
@@ -196,7 +196,7 @@ ThreadedAggregator(im, batches, f) = ThreadedAggregator(f, _inv_aggregation_map(
196
196
197
197
function aggregate! (a:: ThreadedAggregator , aggbuf, data)
198
198
length (a. m) == length (aggbuf) || throw (DimensionMismatch (" length of aggbuf and a.m must be equal" ))
199
- fill! (aggbuf, zero ( eltype ( aggbuf) ))
199
+ fill! (aggbuf, _appropriate_zero ( aggbuf))
200
200
201
201
Threads. @threads for (dstidx, srcidxs) in a. m
202
202
@inbounds for srcidx in srcidxs
@@ -299,3 +299,11 @@ get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+)
299
299
300
300
iscudacompatible (:: Type{<:KAAggregator} ) = true
301
301
iscudacompatible (:: Type{<:SparseAggregator} ) = true
302
+
303
+ function _appropriate_zero (x)
304
+ if isconcretetype (eltype (x))
305
+ zero (eltype (x))
306
+ else
307
+ 0.0 # hopefully that casts to what is needed
308
+ end
309
+ end
0 commit comments