Skip to content

Commit f3b6994

Browse files
committed
Make sure structural_simplify is composable
1 parent e1a36a0 commit f3b6994

File tree

6 files changed

+22
-13
lines changed

6 files changed

+22
-13
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using ModelingToolkit: ODESystem, var_from_nested_derivative, Differential,
1616
states, equations, vars, Symbolic, diff2term, value,
1717
operation, arguments, Sym, Term, simplify, solve_for,
1818
isdiffeq, isdifferential,
19-
get_structure, default_u0, default_p
19+
get_structure, get_reduced_states, default_u0, default_p
2020

2121
using ModelingToolkit.BipartiteGraphs
2222
using ModelingToolkit.SystemStructures

src/structural_transformation/tearing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ function tearing_reassemble(sys; simplify=false)
180180
@set! sys.structure.algeqs = newalgeqs
181181
@set! sys.eqs = neweqs
182182
@set! sys.states = newstates
183+
@set! sys.reduced_states = [get_reduced_states(sys); solvars]
183184
@set! sys.observed = vcat(observed(sys), obseqs)
184185
return sys
185186
end

src/systems/abstractsystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ for prop in [
154154
:inequality_constraints
155155
:controls
156156
:loss
157+
:reduced_states
157158
]
158159
fname1 = Symbol(:get_, prop)
159160
fname2 = Symbol(:has_, prop)
@@ -471,9 +472,9 @@ Structurally simplify algebraic equations in a system and compute the
471472
topological sort of the observed equations.
472473
"""
473474
function structural_simplify(sys::AbstractSystem)
474-
ss = states(sys)
475475
sys = tearing(alias_elimination(sys))
476476
s = structure(sys)
477-
@set! sys.observed = topsort_equations(observed(sys), ss)
477+
fullstates = [get_reduced_states(sys); states(sys)]
478+
@set! sys.observed = topsort_equations(observed(sys), fullstates)
478479
return sys
479480
end

src/systems/alias_elimination.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ using SymbolicUtils: Rewriters
33
const KEEP = typemin(Int)
44

55
function alias_elimination(sys)
6-
sys = flatten(sys)
7-
s = get_structure(sys)
8-
if !(s isa SystemStructure)
9-
sys = initialize_system_structure(sys)
10-
s = structure(sys)
11-
end
6+
# FIXME: update `structure` too
7+
#sys = flatten(sys)
8+
#s = get_structure(sys)
9+
#if !(s isa SystemStructure)
10+
sys = initialize_system_structure(sys)
11+
s = structure(sys)
12+
#end
1213
is_linear_equations, eadj, cadj = find_linear_equations(sys)
1314

1415
v_eliminated, v_types, n_null_vars, degenerate_equations, linear_equations = alias_eliminate_graph(
@@ -18,9 +19,12 @@ function alias_elimination(sys)
1819
s = structure(sys)
1920
@unpack fullvars, graph = s
2021

22+
n_reduced_states = length(v_eliminated) - n_null_vars
23+
reduced_states = similar(v_eliminated, Any, n_reduced_states)
2124
subs = OrderedDict()
22-
if length(v_eliminated) - n_null_vars > 0
23-
for v in v_eliminated[n_null_vars+1:end]
25+
if n_reduced_states > 0
26+
for (i, v) in enumerate(@view v_eliminated[n_null_vars+1:end])
27+
reduced_states[i] = fullvars[v]
2428
subs[fullvars[v]] = iszeroterm(v_types, v) ? 0.0 :
2529
isalias(v_types, v) ? fullvars[alias(v_types, v)] :
2630
-fullvars[negalias(v_types, v)]
@@ -63,6 +67,7 @@ function alias_elimination(sys)
6367

6468
@set! sys.eqs = eqs
6569
@set! sys.states = newstates
70+
@set! sys.reduced_states = [get_reduced_states(sys); reduced_states]
6671
@set! sys.observed = [get_observed(sys); [lhs ~ rhs for (lhs, rhs) in pairs(subs)]]
6772
@set! sys.structure = nothing
6873
return sys

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct ODESystem <: AbstractODESystem
7474
structure: structural information of the system
7575
"""
7676
structure::Any
77+
reduced_states::Vector
7778
end
7879

7980
function ODESystem(
@@ -101,7 +102,7 @@ function ODESystem(
101102
if length(unique(sysnames)) != length(sysnames)
102103
throw(ArgumentError("System names must be unique."))
103104
end
104-
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p, nothing)
105+
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p, nothing, [])
105106
end
106107

107108
var_from_nested_derivative(x, i=0) = (missing, missing)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct NonlinearSystem <: AbstractSystem
4848
structure: structural information of the system
4949
"""
5050
structure::Any
51+
reduced_states::Any
5152
end
5253

5354
function NonlinearSystem(eqs, states, ps;
@@ -60,7 +61,7 @@ function NonlinearSystem(eqs, states, ps;
6061
default_p isa Dict || (default_p = Dict(default_p))
6162
default_u0 = Dict(value(k) => value(default_u0[k]) for k in keys(default_u0))
6263
default_p = Dict(value(k) => value(default_p[k]) for k in keys(default_p))
63-
NonlinearSystem(eqs, value.(states), value.(ps), observed, name, systems, default_u0, default_p, nothing)
64+
NonlinearSystem(eqs, value.(states), value.(ps), observed, name, systems, default_u0, default_p, nothing, [])
6465
end
6566

6667
function calculate_jacobian(sys::NonlinearSystem;sparse=false,simplify=false)

0 commit comments

Comments
 (0)