Skip to content

Commit d27b1fc

Browse files
feat: add circular dependency checking, substitution limit
1 parent c45d100 commit d27b1fc

File tree

2 files changed

+117
-6
lines changed

2 files changed

+117
-6
lines changed

src/systems/problem_utils.jl

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,21 @@ function add_parameter_dependencies!(sys::AbstractSystem, varmap::AbstractDict)
250250
add_observed_equations!(varmap, parameter_dependencies(sys))
251251
end
252252

253+
struct UnexpectedSymbolicValueInVarmap <: Exception
254+
sym::Any
255+
val::Any
256+
end
257+
258+
function Base.showerror(io::IO, err::UnexpectedSymbolicValueInVarmap)
259+
println(io,
260+
"""
261+
Found symbolic value $(err.val) for variable $(err.sym). You may be missing an \
262+
initial condition or have cyclic initial conditions. If this is intended, pass \
263+
`symbolic_u0 = true`. In case the initial conditions are not cyclic but \
264+
require more substitutions to resolve, increase `substitution_limit`.
265+
""")
266+
end
267+
253268
"""
254269
$(TYPEDSIGNATURES)
255270
@@ -278,6 +293,12 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
278293
isempty(missing_vars) || throw(MissingVariablesError(missing_vars))
279294
end
280295
vals = map(x -> varmap[x], vars)
296+
if !allow_symbolic
297+
for (sym, val) in zip(vars, vals)
298+
symbolic_type(val) == NotSymbolic() && continue
299+
throw(UnexpectedSymbolicValueInVarmap(sym, val))
300+
end
301+
end
281302

282303
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
283304
container_type = Array
@@ -298,17 +319,83 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
298319
end
299320
end
300321

322+
struct SubstitutionCycleError <: Exception
323+
required_variables::Vector{Any}
324+
cycle::Vector{Any}
325+
rules::Vector{Any}
326+
end
327+
328+
function Base.showerror(io::IO, err::SubstitutionCycleError)
329+
println(io, "Detected a cyclic initial condition!")
330+
println(io, "Variables being initialized: ", err.required_variables)
331+
println(io)
332+
println(io, "Variables in cycle: ", err.cycle)
333+
println(io)
334+
println(io, "Substitution rules:")
335+
for rule in err.rules
336+
println(io, rule)
337+
end
338+
println(io,
339+
"Pass `check_cyclic_dependency = false` to disable this check if you think this is an error.")
340+
end
341+
342+
"""
343+
$(TYPEDSIGNATURES)
344+
345+
Check if any of the substitution rules in `varmap` lead to cycles involving
346+
variables in `vars`. Throw an error if such a cycle is found.
347+
"""
348+
function check_substitution_cycles(varmap::AbstractDict, vars)
349+
# ordered set so that `vars` are the first `k` in the list
350+
allvars = OrderedSet(vars)
351+
union!(allvars, keys(varmap))
352+
allvars = collect(allvars)
353+
var_to_idx = Dict(allvars .=> eachindex(allvars))
354+
graph = SimpleDiGraph(length(allvars))
355+
356+
buffer = Set()
357+
for (k, v) in varmap
358+
kidx = var_to_idx[k]
359+
if symbolic_type(v) != NotSymbolic()
360+
vars!(buffer, v)
361+
for var in buffer
362+
haskey(var_to_idx, var) || continue
363+
add_edge!(graph, kidx, var_to_idx[var])
364+
end
365+
elseif v isa AbstractArray
366+
for val in v
367+
vars!(buffer, val)
368+
end
369+
for var in buffer
370+
haskey(var_to_idx, var) || continue
371+
add_edge!(graph, kidx, var_to_idx[var])
372+
end
373+
end
374+
empty!(buffer)
375+
end
376+
377+
# detect at most 100 cycles involving at most `length(varmap)` vertices
378+
cycles = Graphs.simplecycles_limited_length(graph, length(varmap), 100)
379+
# only count those which contain variables in `vars`
380+
filter!(Base.Fix1(any, <=(length(vars))), cycles)
381+
382+
if !isempty(cycles)
383+
cyclevars = allvars[cycles[1]]
384+
throw(SubstitutionCycleError(vars, cyclevars, [v => varmap[v] for v in cyclevars]))
385+
end
386+
end
387+
301388
"""
302389
$(TYPEDSIGNATURES)
303390
304391
Performs symbolic substitution on the values in `varmap` for the keys in `vars`, using
305392
`varmap` itself as the set of substitution rules. If an entry in `vars` is not a key
306393
in `varmap`, it is ignored.
307394
"""
308-
function evaluate_varmap!(varmap::AbstractDict, vars)
395+
function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
309396
for k in vars
310397
haskey(varmap, k) || continue
311-
varmap[k] = fixpoint_sub(varmap[k], varmap)
398+
varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit)
312399
end
313400
end
314401

@@ -407,6 +494,10 @@ Keyword arguments:
407494
length of `u0` vector for consistency. If `false`, do not check with equations. This is
408495
forwarded to `check_eqs_u0`
409496
- `symbolic_u0` allows the returned `u0` to be an array of symbolics.
497+
- `check_cyclic_dependency`: Whether to check the operating point for cyclic
498+
substitution rules involving unknowns/parameters.
499+
- `substitution_limit`: The number times to substitute initial conditions into each
500+
other to attempt to arrive at a numeric value.
410501
411502
All other keyword arguments are passed as-is to `constructor`.
412503
"""
@@ -416,7 +507,9 @@ function process_SciMLProblem(
416507
warn_initialize_determined = true, initialization_eqs = [],
417508
eval_expression = false, eval_module = @__MODULE__, fully_determined = false,
418509
check_initialization_units = false, tofloat = true, use_union = false,
419-
u0_constructor = identity, du0map = nothing, check_length = true, symbolic_u0 = false, kwargs...)
510+
u0_constructor = identity, du0map = nothing, check_length = true,
511+
symbolic_u0 = false, check_cyclic_dependency = true,
512+
substitution_limit = length(observed(sys)), kwargs...)
420513
dvs = unknowns(sys)
421514
ps = parameters(sys)
422515
iv = has_iv(sys) ? get_iv(sys) : nothing
@@ -466,7 +559,7 @@ function process_SciMLProblem(
466559
initializeprob = ModelingToolkit.InitializationProblem(
467560
sys, t, u0map, pmap; guesses, warn_initialize_determined,
468561
initialization_eqs, eval_expression, eval_module, fully_determined,
469-
check_units = check_initialization_units)
562+
check_units = check_initialization_units, check_cyclic_dependency)
470563
initializeprobmap = getu(initializeprob, unknowns(sys))
471564

472565
punknowns = [p
@@ -503,7 +596,8 @@ function process_SciMLProblem(
503596
add_observed!(sys, op)
504597
add_parameter_dependencies!(sys, op)
505598

506-
evaluate_varmap!(op, dvs)
599+
check_cyclic_dependency && check_substitution_cycles(op, dvs)
600+
evaluate_varmap!(op, dvs; limit = substitution_limit)
507601

508602
u0 = better_varmap_to_vars(
509603
op, dvs; tofloat = true, use_union = false,
@@ -515,7 +609,8 @@ function process_SciMLProblem(
515609

516610
check_eqs_u0(eqs, dvs, u0; check_length, kwargs...)
517611

518-
evaluate_varmap!(op, ps)
612+
check_cyclic_dependency && check_substitution_cycles(op, ps)
613+
evaluate_varmap!(op, ps; limit = substitution_limit)
519614
if is_split(sys)
520615
p = MTKParameters(sys, op)
521616
else
@@ -527,6 +622,7 @@ function process_SciMLProblem(
527622
du0map = to_varmap(du0map, ddvs)
528623
merge!(op, du0map)
529624

625+
check_cyclic_dependency && check_substitution_cycles(op, ddvs)
530626
du0 = varmap_to_vars(du0map, ddvs; toterm = identity,
531627
tofloat = true)
532628
kwargs = merge(kwargs, (; ddvs))

test/initial_values.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,18 @@ end
149149
pmap = [c1 => 5.0, c2 => 1.0, c3 => 1.2]
150150
oprob = ODEProblem(osys, u0map, (0.0, 10.0), pmap)
151151
end
152+
153+
@testset "Cyclic dependency checking and substitution limits" begin
154+
@variables x(t) y(t)
155+
@mtkbuild sys = ODESystem([D(x) ~ x, D(y) ~ y], t)
156+
@test_throws ModelingToolkit.SubstitutionCycleError ODEProblem(
157+
sys, [x => 2y, y => 3x], (0.0, 1.0); build_initializeprob = false)
158+
@test_throws ModelingToolkit.UnexpectedSymbolicValueInVarmap ODEProblem(
159+
sys, [x => 2y, y => 3x], (0.0, 1.0);
160+
build_initializeprob = false, check_cyclic_dependency = false)
161+
162+
@parameters p q
163+
@mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t)
164+
@test_throws ModelingToolkit.SubstitutionCycleError ODEProblem(
165+
sys, [x => 1, y => 2], (0.0, 1.0), [p => 2q, q => 3p]; build_initializeprob = false)
166+
end

0 commit comments

Comments
 (0)