Skip to content

Commit d83d00a

Browse files
committed
Merge branch 'master' into myb/state_priority
2 parents b27a1eb + fc0c5fd commit d83d00a

File tree

13 files changed

+151
-46
lines changed

13 files changed

+151
-46
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "8.23.0"
4+
version = "8.24.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/bipartite_graph.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,17 @@ end
6363
function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where {U}
6464
if m.inv_match !== nothing
6565
oldv = m.match[i]
66-
isa(oldv, Int) && (m.inv_match[oldv] = unassigned)
66+
# TODO: maybe default Matching to always have an `inv_match`?
67+
68+
# To maintain the invariant that `m.inv_match[m.match[i]] == i`, we need
69+
# to unassign the matching at `m.inv_match[v]` if it exists.
70+
if v isa Int && (iv = m.inv_match[v]) isa Int
71+
m.match[iv] = unassigned
72+
end
73+
if isa(oldv, Int)
74+
@assert m.inv_match[oldv] == i
75+
m.inv_match[oldv] = unassigned
76+
end
6777
isa(v, Int) && (m.inv_match[v] = i)
6878
end
6979
return m.match[i] = v

src/inputoutput.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
160160
# Build control function
161161

162162
"""
163-
(f_oop, f_ip), dvs, p = generate_control_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); implicit_dae = false, ddvs = if implicit_dae
163+
(f_oop, f_ip), dvs, p = generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys); implicit_dae = false, ddvs = if implicit_dae
164164
165-
For a system `sys` that has unbound inputs (as determined by [`unbound_inputs`](@ref)), generate a function with additional input argument `in`
165+
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate a function with additional input argument `in`
166166
```
167167
f_oop : (u,in,p,t) -> rhs
168168
f_ip : (uout,u,in,p,t) -> nothing
@@ -187,7 +187,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
187187
error("No unbound inputs were found in system.")
188188
end
189189

190-
sys, diff_idxs, alge_idxs = io_preprocessing(sys, inputs, []; simplify, kwargs...)
190+
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
191191

192192
dvs = states(sys)
193193
ps = parameters(sys)
@@ -269,8 +269,8 @@ function inputs_to_parameters!(state::TransformationState, io)
269269
@assert new_v > 0
270270
new_var_to_diff[new_i] = new_v
271271
end
272-
@set! structure.var_to_diff = new_var_to_diff
273-
@set! structure.graph = new_graph
272+
@set! structure.var_to_diff = complete(new_var_to_diff)
273+
@set! structure.graph = complete(new_graph)
274274

275275
@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
276276
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2121
ExtraVariablesSystemException,
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
24-
invalidate_cache!, Substitutions, get_or_construct_tearing_state
24+
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
25+
AliasGraph
2526

2627
using ModelingToolkit.BipartiteGraphs
2728
import .BipartiteGraphs: invview

src/structural_transformation/pantelides.jl

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,47 @@ end
7474
7575
Perform Pantelides algorithm.
7676
"""
77-
function pantelides!(state::TransformationState; maxiters = 8000)
77+
function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing} = nothing;
78+
maxiters = 8000)
7879
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
7980
neqs = nsrcs(graph)
8081
nvars = nv(var_to_diff)
8182
vcolor = falses(nvars)
8283
ecolor = falses(neqs)
8384
var_eq_matching = Matching(nvars)
8485
neqs′ = neqs
86+
nnonemptyeqs = count(eq -> !isempty(𝑠neighbors(graph, eq)) && eq_to_diff[eq] === nothing,
87+
1:neqs′)
88+
89+
# Allow matching for the highest differentiated variable that
90+
# currently appears in an equation (or implicit equation in a side ag)
91+
varwhitelist = falses(nvars)
92+
for var in 1:nvars
93+
if var_to_diff[var] === nothing && !varwhitelist[var]
94+
while isempty(𝑑neighbors(graph, var)) && (ag === nothing || !haskey(ag, var))
95+
var′ = invview(var_to_diff)[var]
96+
var′ === nothing && break
97+
var = var′
98+
end
99+
if !isempty(𝑑neighbors(graph, var))
100+
if ag !== nothing && haskey(ag, var)
101+
# TODO: remove lower diff vars from whitelist
102+
c, a = ag[var]
103+
iszero(c) || (varwhitelist[a] = true)
104+
else
105+
varwhitelist[var] = true
106+
end
107+
end
108+
end
109+
end
110+
111+
if nnonemptyeqs > count(varwhitelist)
112+
throw(InvalidSystemException("System is structurally singular"))
113+
end
114+
85115
for k in 1:neqs′
86116
eq′ = k
117+
eq_to_diff[eq′] === nothing || continue
87118
isempty(𝑠neighbors(graph, eq′)) && continue
88119
pathfound = false
89120
# In practice, `maxiters=8000` should never be reached, otherwise, the
@@ -93,7 +124,6 @@ function pantelides!(state::TransformationState; maxiters = 8000)
93124
#
94125
# the derivatives and algebraic variables are zeros in the variable
95126
# association list
96-
varwhitelist = var_to_diff .== nothing
97127
resize!(vcolor, nvars)
98128
fill!(vcolor, false)
99129
resize!(ecolor, neqs)
@@ -103,11 +133,16 @@ function pantelides!(state::TransformationState; maxiters = 8000)
103133
pathfound && break # terminating condition
104134
for var in eachindex(vcolor)
105135
vcolor[var] || continue
106-
# introduce a new variable
107-
nvars += 1
108-
var_diff = var_derivative!(state, var)
109-
push!(var_eq_matching, unassigned)
110-
@assert length(var_eq_matching) == var_diff
136+
if var_to_diff[var] === nothing
137+
# introduce a new variable
138+
nvars += 1
139+
var_diff = var_derivative!(state, var)
140+
push!(var_eq_matching, unassigned)
141+
push!(varwhitelist, false)
142+
@assert length(var_eq_matching) == var_diff
143+
end
144+
varwhitelist[var] = false
145+
varwhitelist[var_to_diff[var]] = true
111146
end
112147

113148
for eq in eachindex(ecolor)

src/structural_transformation/partial_state_selection.jl

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
function partial_state_selection_graph!(state::TransformationState)
1+
function partial_state_selection_graph!(state::TransformationState;
2+
ag::Union{AliasGraph, Nothing} = nothing)
23
find_solvables!(state; allow_symbolic = true)
3-
var_eq_matching = complete(pantelides!(state))
4+
var_eq_matching = complete(pantelides!(state, ag))
45
complete!(state.structure)
56
partial_state_selection_graph!(state.structure, var_eq_matching)
67
end
@@ -86,6 +87,9 @@ function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
8687
filter!(eq -> invview(ict.graph.matching)[eq] === unassigned, to_tear_eqs)
8788
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars),
8889
nothing)
90+
for var in to_tear_vars
91+
var_eq_matching[var] = unassigned
92+
end
8993
for var in to_tear_vars
9094
var_eq_matching[var] = ict.graph.matching[var]
9195
end
@@ -94,7 +98,10 @@ function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
9498
end
9599
end
96100
for var in 1:ndsts(graph)
97-
if varlevel[var] !== 0 && var_eq_matching[var] === unassigned
101+
dv = var_to_diff[var]
102+
# If `var` is not algebraic (not differentiated nor a dummy derivative),
103+
# then it's a SelectedState
104+
if !(dv === nothing || (varlevel[dv] !== 0 && var_eq_matching[dv] === unassigned))
98105
var_eq_matching[var] = SelectedState()
99106
end
100107
end
@@ -116,12 +123,15 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
116123
end
117124

118125
varlevel = map(1:ndsts(graph)) do var
119-
level = 0
126+
graph_level = level = 0
120127
while var_to_diff[var] !== nothing
121128
var = var_to_diff[var]
122129
level += 1
130+
if !isempty(𝑑neighbors(graph, var))
131+
graph_level = level
132+
end
123133
end
124-
level
134+
graph_level
125135
end
126136

127137
inv_varlevel = map(1:ndsts(graph)) do var
@@ -135,7 +145,7 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
135145

136146
# TODO: Should pantelides just return this?
137147
for var in 1:ndsts(graph)
138-
if var_to_diff[var] !== nothing
148+
if varlevel[var] !== 0
139149
var_eq_matching[var] = unassigned
140150
end
141151
end
@@ -147,11 +157,13 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
147157
var_eq_matching
148158
end
149159

150-
function dummy_derivative_graph!(state::TransformationState, jac = nothing, state_priority = nothing; kwargs...)
160+
function dummy_derivative_graph!(state::TransformationState, jac = nothing,
161+
(ag, diff_va) = (nothing, nothing);
162+
state_priority = nothing, kwargs...)
151163
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
152-
var_eq_matching = complete(pantelides!(state))
164+
var_eq_matching = complete(pantelides!(state, ag))
153165
complete!(state.structure)
154-
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority)
166+
dummy_derivative_graph!(state.structure, var_eq_matching, jac, (ag, diff_va), state_priority)
155167
end
156168

157169
function compute_diff_level(diff_to_x)
@@ -171,7 +183,8 @@ function compute_diff_level(diff_to_x)
171183
return xlevel, maxlevel
172184
end
173185

174-
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac, state_priority)
186+
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac,
187+
(ag, diff_va), state_priority)
175188
@unpack eq_to_diff, var_to_diff, graph = structure
176189
diff_to_eq = invview(eq_to_diff)
177190
diff_to_var = invview(var_to_diff)
@@ -240,6 +253,18 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
240253
vars = [diff_to_var[var] for var in vars if diff_to_var[var] !== nothing]
241254
end
242255
end
256+
if diff_va !== nothing
257+
n_dummys = length(dummy_derivatives)
258+
needed = count(x -> x isa Int, diff_to_eq) - n_dummys
259+
n = 0
260+
for v in diff_va
261+
c, a = ag[v]
262+
n += 1
263+
push!(dummy_derivatives, iszero(c) ? v : a)
264+
needed == n && break
265+
continue
266+
end
267+
end
243268

244269
dummy_derivatives_set = BitSet(dummy_derivatives)
245270
# We can eliminate variables that are not a selected state (differential
@@ -249,6 +274,9 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
249274
dummy_derivatives_set = dummy_derivatives_set
250275

251276
v -> begin
277+
if ag !== nothing
278+
haskey(ag, v) && return false
279+
end
252280
dv = var_to_diff[v]
253281
dv === nothing || dv in dummy_derivatives_set
254282
end
@@ -264,7 +292,8 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
264292
Union{Unassigned, SelectedState};
265293
varfilter = can_eliminate)
266294
for v in eachindex(var_eq_matching)
267-
can_eliminate(v) && continue
295+
dv = var_to_diff[v]
296+
(dv === nothing || dv in dummy_derivatives_set) && continue
268297
var_eq_matching[v] = SelectedState()
269298
end
270299

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
567567
eq_to_diff = new_eq_to_diff
568568
diff_to_var = invview(var_to_diff)
569569

570-
@set! state.structure.graph = graph
570+
@set! state.structure.graph = complete(graph)
571571
@set! state.structure.var_to_diff = var_to_diff
572572
@set! state.structure.eq_to_diff = eq_to_diff
573573
@set! state.fullvars = fullvars = fullvars[invvarsperm]
@@ -669,6 +669,6 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwar
669669
p
670670
end
671671
end
672-
var_eq_matching = dummy_derivative_graph!(state, jac, state_priority; kwargs...)
672+
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...)
673673
tearing_reassemble(state, var_eq_matching; simplify = simplify)
674674
end

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ function namespace_assignment(eq::Assignment, sys)
405405
Assignment(_lhs, _rhs)
406406
end
407407

408-
function namespace_expr(O, sys, n = nameof(sys)) where {T}
408+
function namespace_expr(O, sys, n = nameof(sys))
409409
ivs = independent_variables(sys)
410410
O = unwrap(O)
411411
if any(isequal(O), ivs)

src/systems/alias_elimination.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ const KEEP = typemin(Int)
66

77
function alias_eliminate_graph!(state::TransformationState)
88
mm = linear_subsys_adjmat(state)
9-
size(mm, 1) == 0 && return AliasGraph(ndsts(state.structure.graph)), mm, BitSet() # No linear subsystems
9+
if size(mm, 1) == 0
10+
ag = AliasGraph(ndsts(state.structure.graph))
11+
return ag, mm, ag, mm, BitSet() # No linear subsystems
12+
end
1013

1114
@unpack graph, var_to_diff = state.structure
1215

@@ -39,7 +42,7 @@ alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = tru
3942
function alias_elimination!(state::TearingState)
4043
sys = state.sys
4144
complete!(state.structure)
42-
ag, mm, updated_diff_vars = alias_eliminate_graph!(state)
45+
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(state)
4346
isempty(ag) && return sys
4447

4548
fullvars = state.fullvars
@@ -543,6 +546,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
543546
#
544547
nvars = ndsts(graph)
545548
ag = AliasGraph(nvars)
549+
complete_ag = AliasGraph(nvars)
546550
mm, echelon_mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
547551

548552
# Step 3: Handle differentiated variables
@@ -702,6 +706,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
702706
while (iv = diff_to_var[v]) in zero_vars_set
703707
v = iv
704708
end
709+
complete_ag[v] = 0
705710
if diff_to_var[v] === nothing # `v` is reducible
706711
dag[v] = 0
707712
end
@@ -729,6 +734,13 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
729734
# Step 4: Merge dag and ag
730735
removed_aliases = BitSet()
731736
merged_ag = AliasGraph(nvars)
737+
for (v, (c, a)) in dag
738+
complete_ag[v] = c => a
739+
end
740+
for (v, (c, a)) in ag
741+
(processed[v] || (!iszero(a) && processed[a])) && continue
742+
complete_ag[v] = c => a
743+
end
732744
for (v, (c, a)) in dag
733745
# D(x) ~ D(y) cannot be removed if x and y are not aliases
734746
if v != a && !iszero(a)
@@ -789,7 +801,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
789801
update_graph_neighbors!(graph, ag)
790802
end
791803

792-
return ag, mm, updated_diff_vars
804+
complete_mm = reduce!(copy(echelon_mm), mm_orig, complete_ag, size(echelon_mm, 1))
805+
return ag, mm, complete_ag, complete_mm, updated_diff_vars
793806
end
794807

795808
function update_graph_neighbors!(graph, ag)

0 commit comments

Comments
 (0)