Skip to content

Commit c049454

Browse files
committed
Add fast_substitute
Before: ```julia julia> @time sysRed = tearing(sysEx); 8.903631 seconds (42.38 M allocations: 2.968 GiB, 8.06% gc time) ``` After: ```julia julia> @time tearing(sysEx); 1.733097 seconds (10.90 M allocations: 1.059 GiB, 19.44% gc time) ```
1 parent 022008b commit c049454

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25-
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL
25+
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26+
fast_substitute
2627

2728
using ModelingToolkit.BipartiteGraphs
2829
import .BipartiteGraphs: invview, complete

src/structural_transformation/symbolics_tearing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
227227
idx_buffer = Int[]
228228
sub_callback! = let eqs = neweqs, fullvars = fullvars
229229
(ieq, s) -> begin
230-
neweq = substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
230+
neweq = fast_substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
231231
eqs[ieq] = neweq
232232
end
233233
end
@@ -282,7 +282,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
282282
end
283283
for eq in 𝑑neighbors(graph, dv)
284284
dummy_sub[dd] = v_t
285-
neweqs[eq] = substitute(neweqs[eq], dd => v_t)
285+
neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)
286286
end
287287
fullvars[dv] = v_t
288288
# If we have:
@@ -295,7 +295,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
295295
while (ddx = var_to_diff[dx]) !== nothing
296296
dx_t = D(x_t)
297297
for eq in 𝑑neighbors(graph, ddx)
298-
neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t)
298+
neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t)
299299
end
300300
fullvars[ddx] = dx_t
301301
dx = ddx
@@ -655,7 +655,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
655655
obs_sub[eq.lhs] = eq.rhs
656656
end
657657
# TODO: compute the dependency correctly so that we don't have to do this
658-
obs = substitute.([oldobs; subeqs], (obs_sub,))
658+
obs = fast_substitute([oldobs; subeqs], obs_sub)
659659
@set! sys.observed = obs
660660
@set! state.sys = sys
661661
@set! sys.tearing_state = state

src/systems/alias_elimination.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,9 @@ function alias_elimination!(state::TearingState; kwargs...)
151151
k === nothing && break
152152
end
153153
end
154-
subfun = Base.Fix2(substitute, subs)
155154
for ieq in eqs_to_update
156155
eq = eqs[ieq]
157-
eqs[ieq] = subfun(eq.lhs) ~ subfun(eq.rhs)
156+
eqs[ieq] = fast_substitute(eq, subs)
158157
end
159158

160159
for old_ieq in to_expand

src/utils.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,3 +741,36 @@ function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
741741
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
742742
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
743743
end
744+
745+
# Symbolics needs to call unwrap on the substitution rules, but most of the time
746+
# we don't want to do that in MTK.
747+
function fast_substitute(eq::Equation, subs)
748+
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
749+
end
750+
function fast_substitute(eq::Equation, subs::Pair)
751+
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
752+
end
753+
fast_substitute(eqs::AbstractArray{Equation}, subs) = fast_substitute.(eqs, (subs,))
754+
fast_substitute(a, b) = substitute(a, b)
755+
function fast_substitute(expr, pair::Pair)
756+
a, b = pair
757+
isequal(expr, a) && return b
758+
759+
istree(expr) || return expr
760+
op = fast_substitute(operation(expr), pair)
761+
canfold = Ref(!(op isa Symbolic))
762+
args = let canfold = canfold
763+
map(SymbolicUtils.unsorted_arguments(expr)) do x
764+
x′ = fast_substitute(x, pair)
765+
canfold[] = canfold[] && !(x′ isa Symbolic)
766+
x′
767+
end
768+
end
769+
canfold[] && return op(args...)
770+
771+
similarterm(expr,
772+
op,
773+
args,
774+
symtype(expr);
775+
metadata = metadata(expr))
776+
end

0 commit comments

Comments
 (0)