Skip to content

Commit a1b249f

Browse files
feat: add circular dependency checking, substitution limit
1 parent c45d100 commit a1b249f

File tree

2 files changed

+146
-6
lines changed

2 files changed

+146
-6
lines changed

src/systems/problem_utils.jl

Lines changed: 121 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,25 @@ function add_parameter_dependencies!(sys::AbstractSystem, varmap::AbstractDict)
250250
add_observed_equations!(varmap, parameter_dependencies(sys))
251251
end
252252

253+
struct UnexpectedSymbolicValueInVarmap <: Exception
254+
sym::Any
255+
val::Any
256+
end
257+
258+
function Base.showerror(io::IO, err::UnexpectedSymbolicValueInVarmap)
259+
println(io,
260+
"""
261+
Found symbolic value $(err.val) for variable $(err.sym). You may be missing an \
262+
initial condition or have cyclic initial conditions. If this is intended, pass \
263+
`symbolic_u0 = true`. In case the initial conditions are not cyclic but \
264+
require more substitutions to resolve, increase `substitution_limit`. To report \
265+
cycles in initial conditions of unknowns/parameters, pass \
266+
`warn_cyclic_dependency = true`. If the cycles are still not reported, you \
267+
may need to pass a larger value for `circular_dependency_max_cycle_length` \
268+
or `circular_dependency_max_cycles`.
269+
""")
270+
end
271+
253272
"""
254273
$(TYPEDSIGNATURES)
255274
@@ -278,6 +297,12 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
278297
isempty(missing_vars) || throw(MissingVariablesError(missing_vars))
279298
end
280299
vals = map(x -> varmap[x], vars)
300+
if !allow_symbolic
301+
for (sym, val) in zip(vars, vals)
302+
symbolic_type(val) == NotSymbolic() && continue
303+
throw(UnexpectedSymbolicValueInVarmap(sym, val))
304+
end
305+
end
281306

282307
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
283308
container_type = Array
@@ -298,17 +323,68 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
298323
end
299324
end
300325

326+
"""
327+
$(TYPEDSIGNATURES)
328+
329+
Check if any of the substitution rules in `varmap` lead to cycles involving
330+
variables in `vars`. Return a vector of vectors containing all the variables
331+
in each cycle.
332+
333+
Keyword arguments:
334+
- `max_cycle_length`: The maximum length (number of variables) of detected cycles.
335+
- `max_cycles`: The maximum number of cycles to report.
336+
"""
337+
function check_substitution_cycles(
338+
varmap::AbstractDict, vars; max_cycle_length = length(varmap), max_cycles = 10)
339+
# ordered set so that `vars` are the first `k` in the list
340+
allvars = OrderedSet{Any}(vars)
341+
union!(allvars, keys(varmap))
342+
allvars = collect(allvars)
343+
var_to_idx = Dict(allvars .=> eachindex(allvars))
344+
graph = SimpleDiGraph(length(allvars))
345+
346+
buffer = Set()
347+
for (k, v) in varmap
348+
kidx = var_to_idx[k]
349+
if symbolic_type(v) != NotSymbolic()
350+
vars!(buffer, v)
351+
for var in buffer
352+
haskey(var_to_idx, var) || continue
353+
add_edge!(graph, kidx, var_to_idx[var])
354+
end
355+
elseif v isa AbstractArray
356+
for val in v
357+
vars!(buffer, val)
358+
end
359+
for var in buffer
360+
haskey(var_to_idx, var) || continue
361+
add_edge!(graph, kidx, var_to_idx[var])
362+
end
363+
end
364+
empty!(buffer)
365+
end
366+
367+
# detect at most 100 cycles involving at most `length(varmap)` vertices
368+
cycles = Graphs.simplecycles_limited_length(graph, max_cycle_length, max_cycles)
369+
# only count those which contain variables in `vars`
370+
filter!(Base.Fix1(any, <=(length(vars))), cycles)
371+
372+
map(cycles) do cycle
373+
map(Base.Fix1(getindex, allvars), cycle)
374+
end
375+
end
376+
301377
"""
302378
$(TYPEDSIGNATURES)
303379
304380
Performs symbolic substitution on the values in `varmap` for the keys in `vars`, using
305381
`varmap` itself as the set of substitution rules. If an entry in `vars` is not a key
306382
in `varmap`, it is ignored.
307383
"""
308-
function evaluate_varmap!(varmap::AbstractDict, vars)
384+
function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
309385
for k in vars
310386
haskey(varmap, k) || continue
311-
varmap[k] = fixpoint_sub(varmap[k], varmap)
387+
varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit)
312388
end
313389
end
314390

@@ -407,6 +483,14 @@ Keyword arguments:
407483
length of `u0` vector for consistency. If `false`, do not check with equations. This is
408484
forwarded to `check_eqs_u0`
409485
- `symbolic_u0` allows the returned `u0` to be an array of symbolics.
486+
- `warn_cyclic_dependency`: Whether to emit a warning listing out cycles in initial
487+
conditions provided for unknowns and parameters.
488+
- `circular_dependency_max_cycle_length`: Maximum length of cycle to check for.
489+
Only applicable if `warn_cyclic_dependency == true`.
490+
- `circular_dependency_max_cycles`: Maximum number of cycles to check for.
491+
Only applicable if `warn_cyclic_dependency == true`.
492+
- `substitution_limit`: The number times to substitute initial conditions into each
493+
other to attempt to arrive at a numeric value.
410494
411495
All other keyword arguments are passed as-is to `constructor`.
412496
"""
@@ -416,7 +500,11 @@ function process_SciMLProblem(
416500
warn_initialize_determined = true, initialization_eqs = [],
417501
eval_expression = false, eval_module = @__MODULE__, fully_determined = false,
418502
check_initialization_units = false, tofloat = true, use_union = false,
419-
u0_constructor = identity, du0map = nothing, check_length = true, symbolic_u0 = false, kwargs...)
503+
u0_constructor = identity, du0map = nothing, check_length = true,
504+
symbolic_u0 = false, warn_cyclic_dependency = false,
505+
circular_dependency_max_cycle_length = length(all_symbols(sys)),
506+
circular_dependency_max_cycles = 10,
507+
substitution_limit = 100, kwargs...)
420508
dvs = unknowns(sys)
421509
ps = parameters(sys)
422510
iv = has_iv(sys) ? get_iv(sys) : nothing
@@ -466,7 +554,8 @@ function process_SciMLProblem(
466554
initializeprob = ModelingToolkit.InitializationProblem(
467555
sys, t, u0map, pmap; guesses, warn_initialize_determined,
468556
initialization_eqs, eval_expression, eval_module, fully_determined,
469-
check_units = check_initialization_units)
557+
warn_cyclic_dependency, check_units = check_initialization_units,
558+
circular_dependency_max_cycle_length, circular_dependency_max_cycles)
470559
initializeprobmap = getu(initializeprob, unknowns(sys))
471560

472561
punknowns = [p
@@ -503,7 +592,20 @@ function process_SciMLProblem(
503592
add_observed!(sys, op)
504593
add_parameter_dependencies!(sys, op)
505594

506-
evaluate_varmap!(op, dvs)
595+
if warn_cyclic_dependency
596+
cycles = check_substitution_cycles(
597+
op, dvs; max_cycle_length = circular_dependency_max_cycle_length,
598+
max_cycles = circular_dependency_max_cycles)
599+
if !isempty(cycles)
600+
buffer = IOBuffer()
601+
for cycle in cycles
602+
println(buffer, cycle)
603+
end
604+
msg = String(take!(buffer))
605+
@warn "Cycles in unknowns:\n$msg"
606+
end
607+
end
608+
evaluate_varmap!(op, dvs; limit = substitution_limit)
507609

508610
u0 = better_varmap_to_vars(
509611
op, dvs; tofloat = true, use_union = false,
@@ -515,7 +617,20 @@ function process_SciMLProblem(
515617

516618
check_eqs_u0(eqs, dvs, u0; check_length, kwargs...)
517619

518-
evaluate_varmap!(op, ps)
620+
if warn_cyclic_dependency
621+
cycles = check_substitution_cycles(
622+
op, ps; max_cycle_length = circular_dependency_max_cycle_length,
623+
max_cycles = circular_dependency_max_cycles)
624+
if !isempty(cycles)
625+
buffer = IOBuffer()
626+
for cycle in cycles
627+
println(buffer, cycle)
628+
end
629+
msg = String(take!(buffer))
630+
@warn "Cycles in parameters:\n$msg"
631+
end
632+
end
633+
evaluate_varmap!(op, ps; limit = substitution_limit)
519634
if is_split(sys)
520635
p = MTKParameters(sys, op)
521636
else

test/initial_values.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,28 @@ end
149149
pmap = [c1 => 5.0, c2 => 1.0, c3 => 1.2]
150150
oprob = ODEProblem(osys, u0map, (0.0, 10.0), pmap)
151151
end
152+
153+
@testset "Cyclic dependency checking and substitution limits" begin
154+
@variables x(t) y(t)
155+
@mtkbuild sys = ODESystem(
156+
[D(x) ~ x, D(y) ~ y], t; initialization_eqs = [x ~ 2y + 3, y ~ 2x],
157+
guesses = [x => 2y, y => 2x])
158+
@test_warn ["Cycle", "unknowns", "x", "y"] try
159+
ODEProblem(sys, [], (0.0, 1.0), warn_cyclic_dependency = true)
160+
catch
161+
end
162+
@test_throws ModelingToolkit.UnexpectedSymbolicValueInVarmap ODEProblem(
163+
sys, [x => 2y + 1, y => 2x], (0.0, 1.0); build_initializeprob = false)
164+
165+
@parameters p q
166+
@mtkbuild sys = ODESystem(
167+
[D(x) ~ x * p, D(y) ~ y * q], t; guesses = [p => 1.0, q => 2.0])
168+
# "unknowns" because they are initialization unknowns
169+
@test_warn ["Cycle", "unknowns", "p", "q"] try
170+
ODEProblem(sys, [x => 1, y => 2], (0.0, 1.0),
171+
[p => 2q, q => 3p]; warn_cyclic_dependency = true)
172+
catch
173+
end
174+
@test_throws ModelingToolkit.UnexpectedSymbolicValueInVarmap ODEProblem(
175+
sys, [x => 1, y => 2], (0.0, 1.0), [p => 2q, q => 3p])
176+
end

0 commit comments

Comments
 (0)