Skip to content

Commit 300fef0

Browse files
refactor: rewrite benchmarks to account for new derivative calculation
1 parent 7258bd2 commit 300fef0

File tree

3 files changed

+266
-18
lines changed

3 files changed

+266
-18
lines changed

benchmarks/Symbolics/BCR.jmd

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jacobian, generate a function to calculate it and call the function.
1717

1818

1919
```julia
20-
using OrdinaryDiffEq, Catalyst, ReactionNetworkImporters,
20+
using Catalyst, ReactionNetworkImporters,
2121
TimerOutputs, LinearAlgebra, ModelingToolkit, Chairmarks,
2222
LinearSolve, Symbolics, SymbolicUtils, SymbolicUtils.Code, SparseArrays, CairoMakie,
2323
PrettyTables
@@ -36,15 +36,29 @@ osys = convert(ODESystem, rn)
3636
rhs = [eq.rhs for eq in full_equations(osys)]
3737
vars = unknowns(osys)
3838
pars = parameters(osys)
39+
```
40+
41+
The `sparsejacobian` function in Symbolics.jl is optimized for hashconsing and caching, and as such
42+
performs very poorly without either of those features. We use the old implementation, optimized without
43+
hashconsing, to benchmark performance without hashconsing and without caching to avoid biasing the results.
3944

45+
```julia
46+
include("old_sparse_jacobian.jl")
47+
```
48+
49+
```julia
4050
SymbolicUtils.ENABLE_HASHCONSING[] = false
41-
@timeit to "Calculate jacobian - without hashconsing" jac_nohc = Symbolics.sparsejacobian(rhs, vars);
51+
@timeit to "Calculate jacobian - without hashconsing" jac_nohc = old_sparsejacobian(rhs, vars);
4252
SymbolicUtils.ENABLE_HASHCONSING[] = true
43-
SymbolicUtils.toggle_caching!(Symbolics.occursin_info, false)
44-
@timeit to "Calculate jacobian - hashconsing, without caching" jac_hc_nocache = Symbolics.sparsejacobian(rhs, vars);
45-
SymbolicUtils.toggle_caching!(Symbolics.occursin_info, true)
46-
stats = SymbolicUtils.get_stats(Symbolics.occursin_info)
47-
@assert stats.hits == stats.misses == 0
53+
Symbolics.toggle_derivative_caching!(false)
54+
Symbolics.clear_derivative_caches!()
55+
@timeit to "Calculate jacobian - hashconsing, without caching" jac_hc_nocache = old_sparsejacobian(rhs, vars);
56+
Symbolics.toggle_derivative_caching!(true)
57+
for fn in Symbolics.cached_derivative_functions()
58+
stats = SymbolicUtils.get_stats(fn)
59+
@assert stats.hits == stats.misses == 0
60+
end
61+
Symbolics.clear_derivative_caches!()
4862
@timeit to "Calculate jacobian - hashconsing and caching" jac_hc_cache = Symbolics.sparsejacobian(rhs, vars);
4963

5064
@assert isequal(jac_nohc, jac_hc_nocache)
@@ -87,19 +101,16 @@ We'll also measure scaling.
87101
function run_and_time_construct!(rhs, vars, pars, iv, N, i, jac_times, jac_allocs, build_times, functions)
88102
outputs = rhs[1:N]
89103
SymbolicUtils.ENABLE_HASHCONSING[] = false
90-
jac_result = @be Symbolics.sparsejacobian(outputs, vars)
104+
jac_result = @be old_sparsejacobian(outputs, vars)
91105
jac_times[1][i] = minimum(x -> x.time, jac_result.samples)
92106
jac_allocs[1][i] = minimum(x -> x.bytes, jac_result.samples)
93-
jac_nohc = Symbolics.sparsejacobian(outputs, vars)
94107

95108
SymbolicUtils.ENABLE_HASHCONSING[] = true
96-
jac_result = @be (SymbolicUtils.clear_cache!(Symbolics.occursin_info); Symbolics.sparsejacobian(outputs, vars))
109+
jac_result = @be (Symbolics.clear_derivative_caches!(); Symbolics.sparsejacobian(outputs, vars))
97110
jac_times[2][i] = minimum(x -> x.time, jac_result.samples)
98111
jac_allocs[2][i] = minimum(x -> x.bytes, jac_result.samples)
99-
jac_hc = Symbolics.sparsejacobian(outputs, vars)
100112

101-
@assert isequal(jac_nohc, jac_hc)
102-
jac = jac_hc
113+
jac = Symbolics.sparsejacobian(outputs, vars)
103114
args = (vars, pars, iv)
104115
kwargs = (; iip_config = (false, true), expression = Val{true})
105116

benchmarks/Symbolics/ThermalFluid.jmd

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,23 +250,34 @@ end
250250
function call(fn, args...)
251251
fn(args...)
252252
end
253+
```
254+
255+
The `sparsejacobian` function in Symbolics.jl is optimized for hashconsing and caching, and as such
256+
performs very poorly without either of those features. We use the old implementation, optimized without
257+
hashconsing, to benchmark performance without hashconsing and without caching to avoid biasing the results.
258+
259+
```julia
260+
include("old_sparse_jacobian.jl")
253261

254262
function run_and_time_construction!(jacobian_times, jacobian_gctimes, jacobian_allocs, build_times, functions, i, N)
255263
@mtkbuild sys = TestBenchPreinsulated(L=470, N=N, dn=0.3127, t_layer=[0.0056, 0.058])
264+
rhs = [eq.rhs for eq in full_equations(sys)]
265+
dvs = unknowns(sys)
266+
256267
@info "Built system"
257268
SymbolicUtils.ENABLE_HASHCONSING[] = false
258-
jac_result = @be calculate_jacobian(sys; sparse = true)
269+
jac_result = @be old_sparsejacobian(rhs, dvs)
259270
@info "No hashconsing benchmark"
260-
jac_nocse = calculate_jacobian(sys; sparse = true)
271+
jac_nocse = old_sparsejacobian(rhs, dvs)
261272
@info "No hashconsing result"
262273
jacobian_times[1][i] = mean(x -> x.time, jac_result.samples)
263274
jacobian_gctimes[1][i] = mean(x -> x.time * x.gc_fraction, jac_result.samples)
264275
jacobian_allocs[1][i] = mean(x -> x.bytes, jac_result.samples)
265276
@info "times" jacobian_times[1][i] jacobian_gctimes[1][i] jacobian_allocs[1][i]
266277
SymbolicUtils.ENABLE_HASHCONSING[] = true
267-
jac_result = @be (SymbolicUtils.clear_cache!(Symbolics.occursin_info); calculate_jacobian(sys; sparse = true))
278+
jac_result = @be (Symbolics.clear_derivative_caches!(); Symbolics.sparsejacobian(rhs, dvs))
268279
@info "Hashconsing benchmark"
269-
jac_cse = calculate_jacobian(sys; sparse = true)
280+
jac_cse = Symbolics.sparsejacobian(rhs, dvs)
270281
@info "Hashconsing result"
271282
jacobian_times[2][i] = mean(x -> x.time, jac_result.samples)
272283
jacobian_gctimes[2][i] = mean(x -> x.time * x.gc_fraction, jac_result.samples)
@@ -275,7 +286,6 @@ function run_and_time_construction!(jacobian_times, jacobian_gctimes, jacobian_a
275286
@assert isequal(jac_nocse, jac_cse)
276287
jac = jac_cse
277288

278-
dvs = unknowns(sys)
279289
ps = parameters(sys)
280290
defs = defaults(sys)
281291
u0 = Float64[Symbolics.fixpoint_sub(v, defs) for v in dvs]
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
function old_sparsejacobian(ops::AbstractVector, vars::AbstractVector)
2+
sp = Symbolics.jacobian_sparsity(ops, vars)
3+
I,J,_ = findnz(sp)
4+
5+
exprs = old_sparsejacobian_vals(ops, vars, I, J)
6+
7+
sparse(I, J, exprs, length(ops), length(vars))
8+
end
9+
10+
function old_sparsejacobian_vals(ops::AbstractVector, vars::AbstractVector, I::AbstractVector, J::AbstractVector; simplify::Bool=false, kwargs...)
11+
exprs = Num[]
12+
sizehint!(exprs, length(I))
13+
14+
for (i,j) in zip(I, J)
15+
push!(exprs, Num(old_expand_derivatives(Differential(vars[j])(ops[i]), simplify; kwargs...)))
16+
end
17+
exprs
18+
end
19+
20+
21+
function old_expand_derivatives(O::SymbolicUtils.Symbolic, simplify=false; throw_no_derivative=false)
22+
if iscall(O) && isa(operation(O), Differential)
23+
arg = only(arguments(O))
24+
arg = old_expand_derivatives(arg, false; throw_no_derivative)
25+
return old_executediff(operation(O), arg, simplify; throw_no_derivative)
26+
elseif iscall(O) && isa(operation(O), Integral)
27+
return operation(O)(old_expand_derivatives(arguments(O)[1]; throw_no_derivative))
28+
elseif !Symbolics.hasderiv(O)
29+
return O
30+
else
31+
args = map(a->old_expand_derivatives(a, false; throw_no_derivative), arguments(O))
32+
O1 = operation(O)(args...)
33+
return simplify ? SymbolicUtils.simplify(O1) : O1
34+
end
35+
end
36+
function old_expand_derivatives(n::Num, simplify=false; kwargs...)
37+
Symbolics.wrap(old_expand_derivatives(Symbolics.value(n), simplify; kwargs...))
38+
end
39+
40+
function old_occursin_info(x, expr, fail = true)
41+
if SymbolicUtils.symtype(expr) <: AbstractArray
42+
if fail
43+
error("Differentiation with array expressions is not yet supported")
44+
else
45+
return occursin(x, expr)
46+
end
47+
end
48+
49+
# Allow scalarized expressions
50+
function is_scalar_indexed(ex)
51+
(iscall(ex) && operation(ex) == getindex && !(SymbolicUtils.symtype(ex) <: AbstractArray)) ||
52+
(iscall(ex) && (SymbolicUtils.issym(operation(ex)) || iscall(operation(ex))) &&
53+
is_scalar_indexed(operation(ex)))
54+
end
55+
56+
# x[1] == x[1] but not x[2]
57+
if is_scalar_indexed(x) && is_scalar_indexed(expr) &&
58+
isequal(first(arguments(x)), first(arguments(expr)))
59+
return isequal(operation(x), operation(expr)) &&
60+
isequal(arguments(x), arguments(expr))
61+
end
62+
63+
if is_scalar_indexed(x) && is_scalar_indexed(expr) &&
64+
!occursin(first(arguments(x)), first(arguments(expr)))
65+
return false
66+
end
67+
68+
if is_scalar_indexed(expr) && !is_scalar_indexed(x) && !occursin(x, expr)
69+
return false
70+
end
71+
72+
!iscall(expr) && return isequal(x, expr)
73+
if isequal(x, expr)
74+
true
75+
else
76+
args = map(a->old_occursin_info(x, a, operation(expr) !== getindex), arguments(expr))
77+
if all(_isfalse, args)
78+
return false
79+
end
80+
Term{Real}(true, args)
81+
end
82+
end
83+
84+
function old_occursin_info(x, expr::Sym, fail)
85+
if SymbolicUtils.symtype(expr) <: AbstractArray && fail
86+
error("Differentiation of expressions involving arrays and array variables is not yet supported.")
87+
end
88+
isequal(x, expr)
89+
end
90+
91+
_isfalse(occ::Bool) = occ === false
92+
_isfalse(occ::SymbolicUtils.Symbolic) = iscall(occ) && _isfalse(operation(occ))
93+
94+
_iszero(x) = false
95+
_isone(x) = false
96+
_iszero(x::Number) = iszero(x)
97+
_isone(x::Number) = isone(x)
98+
_iszero(::SymbolicUtils.Symbolic) = false
99+
_isone(::SymbolicUtils.Symbolic) = false
100+
_iszero(x::Num) = _iszero(value(x))::Bool
101+
_isone(x::Num) = _isone(value(x))::Bool
102+
103+
104+
function old_executediff(D, arg, simplify=false; occurrences=nothing, throw_no_derivative=false)
105+
if occurrences == nothing
106+
occurrences = old_occursin_info(D.x, arg)
107+
end
108+
109+
_isfalse(occurrences) && return 0
110+
occurrences isa Bool && return 1 # means it's a `true`
111+
112+
if !iscall(arg)
113+
return D(arg) # Cannot expand
114+
elseif (op = operation(arg); SymbolicUtils.issym(op))
115+
inner_args = arguments(arg)
116+
if any(isequal(D.x), inner_args)
117+
return D(arg) # base case if any argument is directly equal to the i.v.
118+
else
119+
return sum(inner_args, init=0) do a
120+
return old_executediff(Differential(a), arg; throw_no_derivative) *
121+
old_executediff(D, a; throw_no_derivative)
122+
end
123+
end
124+
elseif op === getindex
125+
inner_args = arguments(arguments(arg)[1])
126+
c = 0
127+
for a in inner_args
128+
if isequal(a, D.x)
129+
return D(arg)
130+
else
131+
c += Differential(a)(arg) * D(a)
132+
end
133+
end
134+
return old_expand_derivatives(c)
135+
elseif op === ifelse
136+
args = arguments(arg)
137+
O = op(args[1],
138+
old_executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2], throw_no_derivative),
139+
old_executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3], throw_no_derivative))
140+
return O
141+
elseif isa(op, Differential)
142+
# The recursive expand_derivatives was not able to remove
143+
# a nested Differential. We can attempt to differentiate the
144+
# inner expression wrt to the outer iv. And leave the
145+
# unexpandable Differential outside.
146+
if isequal(op.x, D.x)
147+
return D(arg)
148+
else
149+
inner = old_executediff(D, arguments(arg)[1], false; throw_no_derivative)
150+
# if the inner expression is not expandable either, return
151+
if iscall(inner) && operation(inner) isa Differential
152+
return D(arg)
153+
else
154+
# otherwise give the nested Differential another try
155+
return old_executediff(op, inner, simplify; throw_no_derivative)
156+
end
157+
end
158+
elseif isa(op, Integral)
159+
if isa(op.domain.domain, Symbolics.AbstractInterval)
160+
domain = op.domain.domain
161+
a, b = Symbolics.DomainSets.endpoints(domain)
162+
c = 0
163+
inner_function = arguments(arg)[1]
164+
if iscall(value(a))
165+
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
166+
t2 = D(a)
167+
c -= t1*t2
168+
end
169+
if iscall(value(b))
170+
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
171+
t2 = D(b)
172+
c += t1*t2
173+
end
174+
inner = old_executediff(D, arguments(arg)[1]; throw_no_derivative)
175+
c += op(inner)
176+
return Symbolics.value(c)
177+
end
178+
end
179+
180+
inner_args = arguments(arg)
181+
l = length(inner_args)
182+
exprs = []
183+
c = 0
184+
185+
for i in 1:l
186+
t2 = old_executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i], throw_no_derivative)
187+
188+
x = if _iszero(t2)
189+
t2
190+
elseif _isone(t2)
191+
d = Symbolics.derivative_idx(arg, i)
192+
if d isa Symbolics.NoDeriv
193+
throw_no_derivative && error((arg, i))
194+
D(arg)
195+
else
196+
d
197+
end
198+
else
199+
t1 = Symbolics.derivative_idx(arg, i)
200+
t1 = if t1 isa Symbolics.NoDeriv
201+
throw_no_derivative && error((arg, i))
202+
D(arg)
203+
else
204+
t1
205+
end
206+
t1 * t2
207+
end
208+
209+
if _iszero(x)
210+
continue
211+
elseif x isa SymbolicUtils.Symbolic
212+
push!(exprs, x)
213+
else
214+
c += x
215+
end
216+
end
217+
218+
if isempty(exprs)
219+
return c
220+
elseif length(exprs) == 1
221+
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
222+
return _iszero(c) ? term : c + term
223+
else
224+
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
225+
return simplify ? SymbolicUtils.simplify(x) : x
226+
end
227+
end

0 commit comments

Comments
 (0)