Skip to content

Commit 7d9e093

Browse files
authored
Merge pull request #1844 from SciML/myb/state_priority
Add `state_priority` to prefer some states in dummy derivative
2 parents fc0c5fd + c66c5bd commit 7d9e093

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,12 @@ end
159159

160160
function dummy_derivative_graph!(state::TransformationState, jac = nothing,
161161
(ag, diff_va) = (nothing, nothing);
162-
kwargs...)
162+
state_priority = nothing, kwargs...)
163163
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
164164
var_eq_matching = complete(pantelides!(state, ag))
165165
complete!(state.structure)
166-
dummy_derivative_graph!(state.structure, var_eq_matching, jac, (ag, diff_va))
166+
dummy_derivative_graph!(state.structure, var_eq_matching, jac, (ag, diff_va),
167+
state_priority)
167168
end
168169

169170
function compute_diff_level(diff_to_x)
@@ -184,7 +185,7 @@ function compute_diff_level(diff_to_x)
184185
end
185186

186187
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac,
187-
(ag, diff_va))
188+
(ag, diff_va), state_priority)
188189
@unpack eq_to_diff, var_to_diff, graph = structure
189190
diff_to_eq = invview(eq_to_diff)
190191
diff_to_var = invview(var_to_diff)
@@ -204,12 +205,17 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
204205
iszero(maxlevel) && continue
205206

206207
rank_matching = Matching(nvars)
208+
isfirst = true
207209
for _ in maxlevel:-1:1
208210
eqs = filter(eq -> diff_to_eq[eq] !== nothing, eqs)
209211
nrows = length(eqs)
210212
iszero(nrows) && break
211213
eqs_set = BitSet(eqs)
212214

215+
if state_priority !== nothing && isfirst
216+
sort!(vars, by = state_priority)
217+
end
218+
isfirst = false
213219
# TODO: making the algorithm more robust
214220
# 1. If the Jacobian is a integer matrix, use Bareiss to check
215221
# linear independence. (done)

src/structural_transformation/symbolics_tearing.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -648,10 +648,27 @@ Perform index reduction and use the dummy derivative technique to ensure that
648648
the system is balanced.
649649
"""
650650
function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwargs...)
651-
function jac(eqs, vars)
652-
symeqs = EquationsView(state)[eqs]
653-
Symbolics.jacobian((x -> x.rhs).(symeqs), state.fullvars[vars])
651+
jac = let state = state
652+
(eqs, vars) -> begin
653+
symeqs = EquationsView(state)[eqs]
654+
Symbolics.jacobian((x -> x.rhs).(symeqs), state.fullvars[vars])
655+
end
656+
end
657+
state_priority = let state = state
658+
var -> begin
659+
p = 0.0
660+
var_to_diff = state.structure.var_to_diff
661+
diff_to_var = invview(var_to_diff)
662+
while var_to_diff[var] !== nothing
663+
var = var_to_diff[var]
664+
end
665+
while true
666+
p = max(p, ModelingToolkit.state_priority(state.fullvars[var]))
667+
(var = diff_to_var[var]) === nothing && break
668+
end
669+
p
670+
end
654671
end
655-
var_eq_matching = dummy_derivative_graph!(state, jac; kwargs...)
672+
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...)
656673
tearing_reassemble(state, var_eq_matching; simplify = simplify)
657674
end

src/variables.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ struct VariableNoiseType end
44
struct VariableInput end
55
struct VariableOutput end
66
struct VariableIrreducible end
7+
struct VariableStatePriority end
78
Symbolics.option_to_metadata_type(::Val{:unit}) = VariableUnit
89
Symbolics.option_to_metadata_type(::Val{:connect}) = VariableConnectType
910
Symbolics.option_to_metadata_type(::Val{:noise}) = VariableNoiseType
1011
Symbolics.option_to_metadata_type(::Val{:input}) = VariableInput
1112
Symbolics.option_to_metadata_type(::Val{:output}) = VariableOutput
1213
Symbolics.option_to_metadata_type(::Val{:irreducible}) = VariableIrreducible
14+
Symbolics.option_to_metadata_type(::Val{:state_priority}) = VariableStatePriority
1315

1416
abstract type AbstractConnectType end
1517
struct Equality <: AbstractConnectType end # Equality connection
@@ -26,6 +28,7 @@ end
2628
isinput(x) = isvarkind(VariableInput, x)
2729
isoutput(x) = isvarkind(VariableOutput, x)
2830
isirreducible(x) = isvarkind(VariableIrreducible, x) || isinput(x)
31+
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64
2932

3033
"""
3134
$(SIGNATURES)

0 commit comments

Comments
 (0)