Skip to content

Commit b77ed41

Browse files
feat: allow simplifying DAEs to index zero
1 parent eda23d4 commit b77ed41

File tree

3 files changed

+66
-13
lines changed

3 files changed

+66
-13
lines changed

src/structural_transformation/pantelides.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
7171
end
7272

7373
"""
74-
computed_highest_diff_variables(structure)
74+
computed_highest_diff_variables(structure; whitelisted_vars = ())
7575
7676
Computes which variables are the "highest-differentiated" for purposes of
7777
pantelides. Ordinarily this is relatively straightforward. However, in our
@@ -83,12 +83,18 @@ case, there is one complicating condition:
8383
8484
This function takes care of these complications are returns a boolean array
8585
for every variable, indicating whether it is considered "highest-differentiated".
86+
87+
For each index `i` in `whitelisted_vars`, the `i`th variable is included if it
88+
is the highest differentiated variable even if it doesn't appear in the system.
8689
"""
87-
function computed_highest_diff_variables(structure)
90+
function computed_highest_diff_variables(structure; whitelisted_vars = ())
8891
@unpack graph, var_to_diff = structure
8992

9093
nvars = length(var_to_diff)
9194
varwhitelist = falses(nvars)
95+
for i in whitelisted_vars
96+
varwhitelist[i] = true
97+
end
9298
for var in 1:nvars
9399
if var_to_diff[var] === nothing && !varwhitelist[var]
94100
# This variable is structurally highest-differentiated, but may not actually appear in the
@@ -125,7 +131,7 @@ end
125131
Perform Pantelides algorithm.
126132
"""
127133
function pantelides!(
128-
state::TransformationState; finalize = true, maxiters = 8000, kwargs...)
134+
state::TransformationState; finalize = true, maxiters = 8000, whitelisted_vars = (), kwargs...)
129135
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
130136
neqs = nsrcs(graph)
131137
nvars = nv(var_to_diff)
@@ -137,8 +143,7 @@ function pantelides!(
137143
eq -> !isempty(𝑠neighbors(graph, eq)) && eq_to_diff[eq] === nothing,
138144
1:neqs′)
139145

140-
varwhitelist = computed_highest_diff_variables(state.structure)
141-
146+
varwhitelist = computed_highest_diff_variables(state.structure; whitelisted_vars)
142147
if nnonemptyeqs > count(varwhitelist)
143148
throw(InvalidSystemException("System is structurally singular"))
144149
end
@@ -206,14 +211,19 @@ function pantelides!(
206211
end
207212

208213
"""
209-
dae_index_lowering(sys::ODESystem; kwargs...) -> ODESystem
214+
dae_index_lowering(sys::ODESystem; to_index_zero = false, kwargs...) -> ODESystem
210215
211216
Perform the Pantelides algorithm to transform a higher index DAE to an index 1
212217
DAE. `kwargs` are forwarded to [`pantelides!`](@ref). End users are encouraged to call [`structural_simplify`](@ref)
213-
instead, which calls this function internally.
218+
instead, which calls this function internally. If `to_index_zero` is true, the DAE will be reduced to an index 1 DAE.
214219
"""
215-
function dae_index_lowering(sys::ODESystem; kwargs...)
220+
function dae_index_lowering(sys::ODESystem; to_index_zero = false, kwargs...)
216221
state = TearingState(sys)
217-
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
222+
if to_index_zero
223+
newvars = ModelingToolkit.add_missing_differentials!(state)
224+
else
225+
newvars = ()
226+
end
227+
var_eq_matching = pantelides!(state; finalize = false, whitelisted_vars = newvars, kwargs...)
218228
return invalidate_cache!(pantelides_reassemble(state, var_eq_matching))
219229
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,13 @@ Perform index reduction and use the dummy derivative technique to ensure that
812812
the system is balanced.
813813
"""
814814
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
815-
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
815+
mm = nothing, cse_hack = true, array_hack = true, to_index_zero = false, kwargs...)
816+
if to_index_zero
817+
newvars = ModelingToolkit.add_missing_differentials!(state)
818+
else
819+
newvars = ()
820+
end
821+
816822
jac = let state = state
817823
(eqs, vars) -> begin
818824
symeqs = EquationsView(state)[eqs]
@@ -834,7 +840,7 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
834840
p
835841
end
836842
end
837-
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
843+
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, whitelisted_vars = newvars,
838844
kwargs...)
839845
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
840846
end

src/systems/systemstructure.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,37 @@ function shift_discrete_system(ts::TearingState)
480480
return ts
481481
end
482482

483+
"""
484+
$(TYPEDSIGNATURES)
485+
486+
For each variable in `ts.fullvars` which does not have a derivative in `ts.fullvars`
487+
and is not the derivative of a variable in `ts.fullvars`, add its derivative to `ts`.
488+
Returns the indexes of added differential variables.
489+
"""
490+
function add_missing_differentials!(ts::TearingState)
491+
sys = ts.sys
492+
D = Differential(get_iv(sys))
493+
newvars = Int[]
494+
for (i, v) in enumerate(ts.fullvars)
495+
# ignore variables that have a derivative...
496+
ts.structure.var_to_diff[i] === nothing || continue
497+
# or are the derivative
498+
invview(ts.structure.var_to_diff)[i] === nothing || continue
499+
# add to fullvars
500+
push!(ts.fullvars, D(v))
501+
push!(newvars, length(ts.fullvars))
502+
# update diffgraph
503+
add_vertex!(ts.structure.var_to_diff)
504+
add_edge!(ts.structure.var_to_diff, i, length(ts.fullvars))
505+
# update bipartite graphs
506+
add_vertex!(ts.structure.graph, DST)
507+
if ts.structure.solvable_graph !== nothing
508+
add_vertex!(ts.structure.solvable_graph, DST)
509+
end
510+
end
511+
return newvars
512+
end
513+
483514
using .BipartiteGraphs: Label, BipartiteAdjacencyList
484515
struct SystemStructurePrintMatrix <:
485516
AbstractMatrix{Union{Label, BipartiteAdjacencyList}}
@@ -676,6 +707,7 @@ end
676707
function _structural_simplify!(state::TearingState, io; simplify = false,
677708
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
678709
dummy_derivative = true,
710+
to_index_zero = false,
679711
kwargs...)
680712
if fully_determined isa Bool
681713
check_consistency &= fully_determined
@@ -699,9 +731,14 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
699731
end
700732
if fully_determined && dummy_derivative
701733
sys = ModelingToolkit.dummy_derivative(
702-
sys, state; simplify, mm, check_consistency, kwargs...)
734+
sys, state; simplify, mm, check_consistency, to_index_zero, kwargs...)
703735
elseif fully_determined
704-
var_eq_matching = pantelides!(state; finalize = false, kwargs...)
736+
if to_index_zero
737+
newvars = add_missing_differentials!(state)
738+
else
739+
newvars = ()
740+
end
741+
var_eq_matching = pantelides!(state; finalize = false, whitelisted_vars = newvars, kwargs...)
705742
sys = pantelides_reassemble(state, var_eq_matching)
706743
state = TearingState(sys)
707744
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)

0 commit comments

Comments
 (0)