Skip to content

Commit 84d8e12

Browse files
committed
Cache subed equations and only use full_equations in calculate_* functions
1 parent 7cc075d commit 84d8e12

File tree

6 files changed

+38
-21
lines changed

6 files changed

+38
-21
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ export Differential, expand_derivatives, @derivatives
181181
export Equation, ConstrainedEquation
182182
export Term, Sym
183183
export SymScope, LocalScope, ParentScope, GlobalScope
184-
export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure
184+
export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure, full_equations
185185
export structural_simplify, expand_connections
186186
export DiscreteSystem, DiscreteProblem
187187

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Dif
2121
ExtraVariablesSystemException,
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
24-
invalidate_cache!
24+
invalidate_cache!, Substitutions
2525

2626
using ModelingToolkit.BipartiteGraphs
2727
import .BipartiteGraphs: invview
@@ -43,6 +43,7 @@ export tearing_assignments, tearing_substitution
4343
export build_torn_function, build_observed_function, ODAEProblem
4444
export sorted_incidence_matrix
4545
export torn_system_jacobian_sparsity
46+
export full_equations
4647

4748
include("utils.jl")
4849
include("pantelides.jl")

src/structural_transformation/symbolics_tearing.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ function tearing_sub(expr, dict, s)
3434
s ? simplify(expr) : expr
3535
end
3636

37-
function tearing_substitution(sys::AbstractSystem; simplify=false)
38-
empty_substitutions(sys) && return sys
39-
subs, = get_substitutions(sys)
37+
function full_equations(sys::AbstractSystem; simplify=false)
38+
empty_substitutions(sys) && return equations(sys)
39+
substitutions = get_substitutions(sys)
40+
substitutions.subed_eqs === nothing && return substitutions.subed_eqs
41+
@unpack subs = substitutions
4042
solved = Dict(eq.lhs => eq.rhs for eq in subs)
4143
neweqs = map(equations(sys)) do eq
4244
if isdiffeq(eq)
@@ -54,6 +56,12 @@ function tearing_substitution(sys::AbstractSystem; simplify=false)
5456
end
5557
eq
5658
end
59+
substitutions.subed_eqs = neweqs
60+
return neweqs
61+
end
62+
63+
function tearing_substitution(sys::AbstractSystem; kwargs...)
64+
neweqs = full_equations(sys::AbstractSystem; kwargs...)
5765
@set! sys.eqs = neweqs
5866
@set! sys.substitutions = nothing
5967
end
@@ -64,7 +72,7 @@ function tearing_assignments(sys::AbstractSystem)
6472
deps = Int[]
6573
sol_states = Code.LazyState()
6674
else
67-
subs, deps = get_substitutions(sys)
75+
@unpack subs, deps = get_substitutions(sys)
6876
assignments = [Assignment(eq.lhs, eq.rhs) for eq in subs]
6977
sol_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
7078
end
@@ -105,13 +113,13 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
105113
end
106114
subgraph = substitution_graph(graph, solved_equations, solved_variables, var_eq_matching)
107115
toporder = topological_sort_by_dfs(subgraph)
108-
substitutions = [solve_equation(
109-
eqs[solved_equations[i]],
110-
fullvars[solved_variables[i]],
111-
simplify
112-
) for i in toporder]
116+
subeqs = [solve_equation(
117+
eqs[solved_equations[i]],
118+
fullvars[solved_variables[i]],
119+
simplify
120+
) for i in toporder]
113121
invtoporder = invperm(toporder)
114-
deps = [[invtoporder[n] for n in neighborhood(subgraph, j, Inf, dir=:in) if n!=j] for (i, j) in enumerate(toporder)]
122+
deps = [Int[invtoporder[n] for n in neighborhood(subgraph, j, Inf, dir=:in) if n!=j] for (i, j) in enumerate(toporder)]
115123

116124
# Rewrite remaining equations in terms of solved variables
117125

@@ -134,8 +142,8 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
134142
@set! sys.structure = s
135143
@set! sys.eqs = neweqs
136144
@set! sys.states = [s.fullvars[idx] for idx in 1:length(s.fullvars) if !isdervar(s, idx)]
137-
@set! sys.observed = [observed(sys); substitutions]
138-
@set! sys.substitutions = substitutions, deps
145+
@set! sys.observed = [observed(sys); subeqs]
146+
@set! sys.substitutions = Substitutions(subeqs, deps)
139147
return sys
140148
end
141149

src/systems/abstractsystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ Generate a function to evaluate the system's equations.
129129
"""
130130
function generate_function end
131131

132+
133+
mutable struct Substitutions
134+
subs::Vector{Equation}
135+
deps::Vector{Vector{Int}}
136+
subed_eqs::Union{Nothing,Vector{Equation}}
137+
end
138+
Substitutions(subs, deps) = Substitutions(subs, deps, nothing)
139+
132140
Base.nameof(sys::AbstractSystem) = getfield(sys, :name)
133141

134142
#Deprecated

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function calculate_tgrad(sys::AbstractODESystem;
55
# We need to remove explicit time dependence on the state because when we
66
# have `u(t) * t` we want to have the tgrad to be `u(t)` instead of `u'(t) *
77
# t + u(t)`.
8-
rhs = [detime_dvs(eq.rhs) for eq equations(sys)]
8+
rhs = [detime_dvs(eq.rhs) for eq full_equations(sys)]
99
iv = get_iv(sys)
1010
xs = states(sys)
1111
rule = Dict(map((x, xt) -> xt=>x, detime_dvs.(xs), xs))
@@ -23,7 +23,7 @@ function calculate_jacobian(sys::AbstractODESystem;
2323
if cache isa Tuple && cache[2] == (sparse, simplify)
2424
return cache[1]
2525
end
26-
rhs = [eq.rhs for eq equations(sys)]
26+
rhs = [eq.rhs for eq full_equations(sys)]
2727

2828
iv = get_iv(sys)
2929
dvs = states(sys)
@@ -45,7 +45,7 @@ function calculate_control_jacobian(sys::AbstractODESystem;
4545
return cache[1]
4646
end
4747

48-
rhs = [eq.rhs for eq equations(sys)]
48+
rhs = [eq.rhs for eq full_equations(sys)]
4949

5050
iv = get_iv(sys)
5151
ctrls = controls(sys)
@@ -263,7 +263,7 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
263263
end
264264

265265
function calculate_massmatrix(sys::AbstractODESystem; simplify=false)
266-
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
266+
eqs = [eq for eq in full_equations(sys) if !isdifferenceeq(eq)]
267267
dvs = states(sys)
268268
M = zeros(length(eqs),length(eqs))
269269
state2idx = Dict(s => i for (i, s) in enumerate(dvs))
@@ -285,7 +285,7 @@ function jacobian_sparsity(sys::AbstractODESystem)
285285
sparsity = torn_system_jacobian_sparsity(sys)
286286
sparsity === nothing || return sparsity
287287

288-
jacobian_sparsity([eq.rhs for eq equations(sys)],
288+
jacobian_sparsity([eq.rhs for eq full_equations(sys)],
289289
[dv for dv in states(sys)])
290290
end
291291

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,15 @@ isarray(x) = x isa AbstractArray || x isa Symbolics.Arr
430430
function empty_substitutions(sys)
431431
has_substitutions(sys) || return true
432432
subs = get_substitutions(sys)
433-
isnothing(subs) || isempty(last(subs))
433+
isnothing(subs) || isempty(subs.deps)
434434
end
435435

436436
function get_substitutions_and_solved_states(sys; no_postprocess=false)
437437
if empty_substitutions(sys)
438438
sol_states = Code.LazyState()
439439
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)
440440
else
441-
subs, = get_substitutions(sys)
441+
@unpack subs = get_substitutions(sys)
442442
sol_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
443443
if no_postprocess
444444
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], ex)

0 commit comments

Comments
 (0)