Skip to content

Commit a2111de

Browse files
committed
Cache tearing state
1 parent 52cace3 commit a2111de

File tree

8 files changed

+48
-21
lines changed

8 files changed

+48
-21
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
1515
operation, arguments, Sym, Term, simplify, solve_for,
1616
isdiffeq, isdifferential, isinput,
1717
empty_substitutions, get_substitutions,
18-
get_structure, get_iv, independent_variables,
19-
has_structure, defaults, InvalidSystemException,
18+
get_tearing_state, get_iv, independent_variables,
19+
has_tearing_state, defaults, InvalidSystemException,
2020
ExtraEquationsSystemException,
2121
ExtraVariablesSystemException,
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
24-
invalidate_cache!, Substitutions
24+
invalidate_cache!, Substitutions, get_or_construct_tearing_state
2525

2626
using ModelingToolkit.BipartiteGraphs
2727
import .BipartiteGraphs: invview

src/structural_transformation/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ function build_torn_function(
245245
push!(rhss, eq.rhs)
246246
end
247247

248-
state = TearingState(sys)
248+
state = get_or_construct_tearing_state(sys)
249249
fullvars = state.fullvars
250250
var_eq_matching, var_sccs = algebraic_variables_scc(state)
251251
condensed_graph = MatchedCondensationGraph(

src/structural_transformation/symbolics_tearing.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,24 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify=false
220220

221221
# Contract the vertices in the structure graph to make the structure match
222222
# the new reality of the system we've just created.
223-
#graph = contract_variables(graph, var_eq_matching, solved_variables)
223+
graph = contract_variables(graph, var_eq_matching, solved_variables)
224224

225225
# Update system
226226
active_vars = setdiff(BitSet(1:length(fullvars)), solved_variables)
227227

228+
@set! state.structure.graph = graph
229+
@set! state.fullvars = [v for (i, v) in enumerate(fullvars) if i in active_vars]
230+
228231
sys = state.sys
229232
@set! sys.eqs = neweqs
230233
isstatediff(i) = var_eq_matching[i] !== SelectedState() && invview(var_to_diff)[i] !== nothing && var_eq_matching[invview(var_to_diff)[i]] === SelectedState()
231234
@set! sys.states = [fullvars[i] for i in active_vars if !isstatediff(i)]
232235
@set! sys.observed = [observed(sys); subeqs]
233236
@set! sys.substitutions = Substitutions(subeqs, deps)
237+
@set! state.sys = sys
238+
@set! sys.tearing_state = state
234239

235-
return sys
240+
return invalidate_cache!(sys)
236241
end
237242

238243
function tearing(state::TearingState)

src/structural_transformation/utils.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,11 @@ function uneven_invmap(n::Int, list)
258258
end
259259

260260
function torn_system_jacobian_sparsity(sys)
261-
has_structure(sys) || return nothing
262-
get_structure(sys) isa SystemStructure || return nothing
261+
state = get_tearing_state(sys)
262+
state isa TearingState || return nothing
263263
s = structure(sys)
264-
@unpack fullvars, graph = s
264+
graph = state.structure.graph
265+
fullvars = state.fullvars
265266

266267
states_idxs = findall(!isdifferential, fullvars)
267268
var2idx = Dict{Int,Int}(v => i for (i, v) in enumerate(states_idxs))

src/systems/abstractsystem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ for prop in [
225225
:connections
226226
:preface
227227
:torn_matching
228+
:tearing_state
228229
:substitutions
229230
]
230231
fname1 = Symbol(:get_, prop)
@@ -683,6 +684,18 @@ end
683684

684685
Base.write(io::IO, sys::AbstractSystem) = write(io, readable_code(toexpr(sys)))
685686

687+
function get_or_construct_tearing_state(sys)
688+
if has_tearing_state(sys)
689+
state = get_tearing_state(sys)
690+
if state === nothing
691+
state = TearingState(sys)
692+
end
693+
else
694+
state = nothing
695+
end
696+
state
697+
end
698+
686699
function Base.show(io::IO, ::MIME"text/plain", sys::AbstractSystem)
687700
eqs = equations(sys)
688701
if eqs isa AbstractArray
@@ -741,7 +754,7 @@ function Base.show(io::IO, ::MIME"text/plain", sys::AbstractSystem)
741754
if has_torn_matching(sys)
742755
# If the system can take a torn matching, then we can initialize a tearing
743756
# state on it. Do so and get show the structure.
744-
state = TearingState(sys; check=false)
757+
state = get_or_construct_tearing_state(sys)
745758
if state !== nothing
746759
Base.printstyled(io, "\nIncidence matrix:"; color=:magenta)
747760
show(io, incidence_matrix(state.structure.graph, Num(Sym{Real}(:×))))

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,18 @@ struct ODESystem <: AbstractODESystem
102102
"""
103103
continuous_events::Vector{SymbolicContinuousCallback}
104104
"""
105+
tearing_state: cache for intermediate tearing state
106+
"""
107+
tearing_state::Any
108+
"""
105109
substitutions: substitutions generated by tearing.
106110
"""
107111
substitutions::Any
108112

109113
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
110114
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
111115
torn_matching, connector_type, connections, preface, events,
112-
substitutions=nothing; checks::Bool=true)
116+
tearing_state=nothing, substitutions=nothing; checks::Bool=true)
113117
if checks
114118
check_variables(dvs,iv)
115119
check_parameters(ps,iv)
@@ -119,7 +123,7 @@ struct ODESystem <: AbstractODESystem
119123
end
120124
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
121125
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
122-
connector_type, connections, preface, events, substitutions)
126+
connector_type, connections, preface, events, tearing_state, substitutions)
123127
end
124128
end
125129

src/systems/discrete_system/discrete_system.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
5252
"""
5353
defaults::Dict
5454
"""
55-
structure: structural information of the system
56-
"""
57-
structure::Any
58-
"""
5955
preface: inject assignment statements before the evaluation of the RHS function.
6056
"""
6157
preface::Any
@@ -64,17 +60,21 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
6460
"""
6561
connector_type::Any
6662
"""
63+
tearing_state: cache for intermediate tearing state
64+
"""
65+
tearing_state::Any
66+
"""
6767
substitutions: substitutions generated by tearing.
6868
"""
6969
substitutions::Any
7070

71-
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, preface, connector_type, substitutions=nothing; checks::Bool = true)
71+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, preface, connector_type, tearing_state=nothing, substitutions=nothing; checks::Bool = true)
7272
if checks
7373
check_variables(dvs, iv)
7474
check_parameters(ps, iv)
7575
all_dimensionless([dvs;ps;iv;ctrls]) || check_units(discreteEqs)
7676
end
77-
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, preface, connector_type, substitutions)
77+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, preface, connector_type, tearing_state, substitutions)
7878
end
7979
end
8080

@@ -118,7 +118,7 @@ function DiscreteSystem(
118118
if length(unique(sysnames)) != length(sysnames)
119119
throw(ArgumentError("System names must be unique."))
120120
end
121-
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, nothing, preface, connector_type, kwargs...)
121+
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, preface, connector_type, kwargs...)
122122
end
123123

124124

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,19 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
5151
"""
5252
connector_type::Any
5353
"""
54+
tearing_state: cache for intermediate tearing state
55+
"""
56+
tearing_state::Any
57+
"""
5458
substitutions: substitutions generated by tearing.
5559
"""
5660
substitutions::Any
5761

58-
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, connector_type, substitutions=nothing; checks::Bool = true)
62+
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, connector_type, tearing_state=nothing, substitutions=nothing; checks::Bool = true)
5963
if checks
6064
all_dimensionless([states;ps]) ||check_units(eqs)
6165
end
62-
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, connector_type, substitutions)
66+
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, connector_type, tearing_state, substitutions)
6367
end
6468
end
6569

0 commit comments

Comments
 (0)