Skip to content

Commit f120310

Browse files
Merge pull request #3235 from AayushSabharwal/as/init-fully-determined
feat: simplify initialization systems with `fully_determined=true` if possible
2 parents 42d4d63 + 2b7a6b6 commit f120310

File tree

6 files changed

+76
-21
lines changed

6 files changed

+76
-21
lines changed

src/structural_transformation/utils.jl

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,41 @@ end
5858
###
5959
### Structural check
6060
###
61-
function check_consistency(state::TransformationState, orig_inputs)
61+
62+
"""
63+
$(TYPEDSIGNATURES)
64+
65+
Check if the `state` represents a singular system, and return the unmatched variables.
66+
"""
67+
function singular_check(state::TransformationState)
68+
@unpack graph, var_to_diff = state.structure
69+
fullvars = get_fullvars(state)
70+
# This is defined to check if Pantelides algorithm terminates. For more
71+
# details, check the equation (15) of the original paper.
72+
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
73+
map(collect, edges(var_to_diff))])
74+
extended_var_eq_matching = maximal_matching(extended_graph)
75+
76+
nvars = ndsts(graph)
77+
unassigned_var = []
78+
for (vj, eq) in enumerate(extended_var_eq_matching)
79+
vj > nvars && break
80+
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
81+
push!(unassigned_var, fullvars[vj])
82+
end
83+
end
84+
return unassigned_var
85+
end
86+
87+
"""
88+
$(TYPEDSIGNATURES)
89+
90+
Check the consistency of `state`, given the inputs `orig_inputs`. If `nothrow == false`,
91+
throws an error if the system is under-/over-determined or singular. In this case, if the
92+
function returns it will return `true`. If `nothrow == true`, it will return `false`
93+
instead of throwing an error. The singular case will print a warning.
94+
"""
95+
function check_consistency(state::TransformationState, orig_inputs; nothrow = false)
6296
fullvars = get_fullvars(state)
6397
neqs = n_concrete_eqs(state)
6498
@unpack graph, var_to_diff = state.structure
@@ -72,6 +106,7 @@ function check_consistency(state::TransformationState, orig_inputs)
72106
is_balanced = n_highest_vars == neqs
73107

74108
if neqs > 0 && !is_balanced
109+
nothrow && return false
75110
varwhitelist = var_to_diff .== nothing
76111
var_eq_matching = maximal_matching(graph, eq -> true, v -> varwhitelist[v]) # not assigned
77112
# Just use `error_reporting` to do conditional
@@ -85,22 +120,12 @@ function check_consistency(state::TransformationState, orig_inputs)
85120
error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
86121
end
87122

88-
# This is defined to check if Pantelides algorithm terminates. For more
89-
# details, check the equation (15) of the original paper.
90-
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
91-
map(collect, edges(var_to_diff))])
92-
extended_var_eq_matching = maximal_matching(extended_graph)
93-
94-
nvars = ndsts(graph)
95-
unassigned_var = []
96-
for (vj, eq) in enumerate(extended_var_eq_matching)
97-
vj > nvars && break
98-
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
99-
push!(unassigned_var, fullvars[vj])
100-
end
101-
end
123+
unassigned_var = singular_check(state)
102124

103125
if !isempty(unassigned_var) || !is_balanced
126+
if nothrow
127+
return false
128+
end
104129
io = IOBuffer()
105130
Base.print_array(io, unassigned_var)
106131
unassigned_var_str = String(take!(io))
@@ -110,7 +135,7 @@ function check_consistency(state::TransformationState, orig_inputs)
110135
throw(InvalidSystemException(errmsg))
111136
end
112137

113-
return nothing
138+
return true
114139
end
115140

116141
###

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3335,7 +3335,8 @@ function parse_variable(sys::AbstractSystem, str::AbstractString)
33353335
# I'd write a regex to validate `str`, but https://xkcd.com/1171/
33363336
str = strip(str)
33373337
derivative_level = 0
3338-
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) && endswith(str, ")")
3338+
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) &&
3339+
endswith(str, ")")
33393340
if cond1
33403341
derivative_level += 1
33413342
str = _string_view_inner(str, 2, 1)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
12951295
check_length = true,
12961296
warn_initialize_determined = true,
12971297
initialization_eqs = [],
1298-
fully_determined = false,
1298+
fully_determined = nothing,
12991299
check_units = true,
13001300
kwargs...) where {iip, specialize}
13011301
if !iscomplete(sys)
@@ -1313,6 +1313,19 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13131313
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
13141314
end
13151315

1316+
ts = get_tearing_state(isys)
1317+
if warn_initialize_determined &&
1318+
(unassigned_vars = StructuralTransformations.singular_check(ts); !isempty(unassigned_vars))
1319+
errmsg = """
1320+
The initialization system is structurally singular. Guess values may \
1321+
significantly affect the initial values of the ODE. The problematic variables \
1322+
are $unassigned_vars.
1323+
1324+
Note that the identification of problematic variables is a best-effort heuristic.
1325+
"""
1326+
@warn errmsg
1327+
end
1328+
13161329
uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
13171330

13181331
# TODO: throw on uninitialized arrays

src/systems/problem_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ function process_SciMLProblem(
498498
constructor, sys::AbstractSystem, u0map, pmap; build_initializeprob = true,
499499
implicit_dae = false, t = nothing, guesses = AnyDict(),
500500
warn_initialize_determined = true, initialization_eqs = [],
501-
eval_expression = false, eval_module = @__MODULE__, fully_determined = false,
501+
eval_expression = false, eval_module = @__MODULE__, fully_determined = nothing,
502502
check_initialization_units = false, tofloat = true, use_union = false,
503503
u0_constructor = identity, du0map = nothing, check_length = true,
504504
symbolic_u0 = false, warn_cyclic_dependency = false,

src/systems/systemstructure.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,11 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
677677
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
678678
dummy_derivative = true,
679679
kwargs...)
680-
check_consistency &= fully_determined
680+
if fully_determined isa Bool
681+
check_consistency &= fully_determined
682+
else
683+
check_consistency = true
684+
end
681685
has_io = io !== nothing
682686
orig_inputs = Set()
683687
if has_io
@@ -690,7 +694,8 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
690694
end
691695
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
692696
if check_consistency
693-
ModelingToolkit.check_consistency(state, orig_inputs)
697+
fully_determined = ModelingToolkit.check_consistency(
698+
state, orig_inputs; nothrow = fully_determined === nothing)
694699
end
695700
if fully_determined && dummy_derivative
696701
sys = ModelingToolkit.dummy_derivative(

test/initializationsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,14 @@ end
947947

948948
@test_nowarn remake(prob, p = prob.p)
949949
end
950+
951+
@testset "Singular initialization prints a warning" begin
952+
@parameters g
953+
@variables x(t) y(t) [state_priority = 10] λ(t)
954+
eqs = [D(D(x)) ~ λ * x
955+
D(D(y)) ~ λ * y - g
956+
x^2 + y^2 ~ 1]
957+
@mtkbuild pend = ODESystem(eqs, t)
958+
@test_warn ["structurally singular", "initialization", "Guess", "heuristic"] ODEProblem(
959+
pend, [x => 1, y => 0], (0.0, 1.5), [g => 1], guesses ==> 1])
960+
end

0 commit comments

Comments
 (0)