Skip to content

Commit 81f1f55

Browse files
committed
make aggbuf init work for {Any} diffcache
this can happen for sparsity detection for example
1 parent c828048 commit 81f1f55

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

src/aggregators.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727
NaiveAggregator(f) = (im, batches) -> NaiveAggregator(im, batches, f)
2828

2929
function aggregate!(a::NaiveAggregator, aggbuf, data)
30-
fill!(aggbuf, zero(eltype(aggbuf)))
30+
fill!(aggbuf, _appropriate_zero(aggbuf))
3131
_aggregate!(a, a.batches, aggbuf, data)
3232
end
3333
function _aggregate!(a::NaiveAggregator, batches, aggbuf, data)
@@ -101,7 +101,7 @@ KAAggregator(im, batches, f) = KAAggregator(f, AggregationMap(im, batches))
101101

102102
function aggregate!(a::KAAggregator, aggbuf, data)
103103
am = a.m
104-
fill!(aggbuf, zero(eltype(aggbuf)))
104+
fill!(aggbuf, _appropriate_zero(aggbuf))
105105
_backend = get_backend(data)
106106
# kernel = agg_kernel!(_backend, 1024, length(am.map))
107107
# kernel(a.f, aggbuf, view(data, am.range), am.map)
@@ -140,7 +140,7 @@ SequentialAggregator(f) = (im, batches) -> SequentialAggregator(im, batches, f)
140140
SequentialAggregator(im, batches, f) = SequentialAggregator(f, AggregationMap(im, batches))
141141

142142
function aggregate!(a::SequentialAggregator, aggbuf, data)
143-
fill!(aggbuf, zero(eltype(aggbuf)))
143+
fill!(aggbuf, _appropriate_zero(aggbuf))
144144

145145
am = a.m
146146
@inbounds begin
@@ -169,7 +169,7 @@ PolyesterAggregator(im, batches, f) = PolyesterAggregator(f, _inv_aggregation_ma
169169

170170
function aggregate!(a::PolyesterAggregator, aggbuf, data)
171171
length(a.m) == length(aggbuf) || throw(DimensionMismatch("length of aggbuf and a.m must be equal"))
172-
fill!(aggbuf, zero(eltype(aggbuf)))
172+
fill!(aggbuf, _appropriate_zero(aggbuf))
173173

174174
maxdepth = mapreduce(x -> length(x[2]), max, a.m)
175175

@@ -196,7 +196,7 @@ ThreadedAggregator(im, batches, f) = ThreadedAggregator(f, _inv_aggregation_map(
196196

197197
function aggregate!(a::ThreadedAggregator, aggbuf, data)
198198
length(a.m) == length(aggbuf) || throw(DimensionMismatch("length of aggbuf and a.m must be equal"))
199-
fill!(aggbuf, zero(eltype(aggbuf)))
199+
fill!(aggbuf, _appropriate_zero(aggbuf))
200200

201201
Threads.@threads for (dstidx, srcidxs) in a.m
202202
@inbounds for srcidx in srcidxs
@@ -299,3 +299,11 @@ get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+)
299299

300300
iscudacompatible(::Type{<:KAAggregator}) = true
301301
iscudacompatible(::Type{<:SparseAggregator}) = true
302+
303+
function _appropriate_zero(x)
304+
if isconcretetype(eltype(x))
305+
zero(eltype(x))
306+
else
307+
0.0 # hopefully that casts to what is needed
308+
end
309+
end

test/aggregators_test.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using Chairmarks
66
using InteractiveUtils
77
using Test
88
using StableRNGs
9+
using ForwardDiff: Dual
10+
using Symbolics
911

1012
(isinteractive() && @__MODULE__()==Main ? includet : include)("ComponentLibrary.jl")
1113

@@ -61,3 +63,14 @@ using StableRNGs
6163
@test issame
6264
end
6365
end
66+
67+
@testset "Test _appropriate_zero" begin
68+
@test NetworkDynamics._appropriate_zero([1,2,3]) isa Int
69+
@test NetworkDynamics._appropriate_zero([1,2,3]) == 0
70+
@test NetworkDynamics._appropriate_zero([1.0,2,3]) isa Float64
71+
@test NetworkDynamics._appropriate_zero([1.0,2,3]) == 0.0
72+
@test NetworkDynamics._appropriate_zero([Dual(1.0), 2, 3]) == Dual(0.0)
73+
@variables x, y, z
74+
@test NetworkDynamics._appropriate_zero([x,y,z]) isa Num
75+
@test NetworkDynamics._appropriate_zero(Any[x,y,z]) isa Float64
76+
end

0 commit comments

Comments
 (0)