Skip to content

Commit da1552a

Browse files
feat: add circular dependency checking, substitution limit
1 parent c45d100 commit da1552a

File tree

2 files changed

+142
-6
lines changed

2 files changed

+142
-6
lines changed

src/systems/problem_utils.jl

Lines changed: 117 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,21 @@ function add_parameter_dependencies!(sys::AbstractSystem, varmap::AbstractDict)
250250
add_observed_equations!(varmap, parameter_dependencies(sys))
251251
end
252252

253+
struct UnexpectedSymbolicValueInVarmap <: Exception
254+
sym::Any
255+
val::Any
256+
end
257+
258+
function Base.showerror(io::IO, err::UnexpectedSymbolicValueInVarmap)
259+
println(io,
260+
"""
261+
Found symbolic value $(err.val) for variable $(err.sym). You may be missing an \
262+
initial condition or have cyclic initial conditions. If this is intended, pass \
263+
`symbolic_u0 = true`. In case the initial conditions are not cyclic but \
264+
require more substitutions to resolve, increase `substitution_limit`.
265+
""")
266+
end
267+
253268
"""
254269
$(TYPEDSIGNATURES)
255270
@@ -278,6 +293,12 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
278293
isempty(missing_vars) || throw(MissingVariablesError(missing_vars))
279294
end
280295
vals = map(x -> varmap[x], vars)
296+
if !allow_symbolic
297+
for (sym, val) in zip(vars, vals)
298+
symbolic_type(val) == NotSymbolic() && continue
299+
throw(UnexpectedSymbolicValueInVarmap(sym, val))
300+
end
301+
end
281302

282303
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
283304
container_type = Array
@@ -298,17 +319,68 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
298319
end
299320
end
300321

322+
"""
323+
$(TYPEDSIGNATURES)
324+
325+
Check if any of the substitution rules in `varmap` lead to cycles involving
326+
variables in `vars`. Return a vector of vectors containing all the variables
327+
in each cycle.
328+
329+
Keyword arguments:
330+
- `max_cycle_length`: The maximum length (number of variables) of detected cycles.
331+
- `max_cycles`: The maximum number of cycles to report.
332+
"""
333+
function check_substitution_cycles(
334+
varmap::AbstractDict, vars; max_cycle_length = length(varmap), max_cycles = 10)
335+
# ordered set so that `vars` are the first `k` in the list
336+
allvars = OrderedSet{Any}(vars)
337+
union!(allvars, keys(varmap))
338+
allvars = collect(allvars)
339+
var_to_idx = Dict(allvars .=> eachindex(allvars))
340+
graph = SimpleDiGraph(length(allvars))
341+
342+
buffer = Set()
343+
for (k, v) in varmap
344+
kidx = var_to_idx[k]
345+
if symbolic_type(v) != NotSymbolic()
346+
vars!(buffer, v)
347+
for var in buffer
348+
haskey(var_to_idx, var) || continue
349+
add_edge!(graph, kidx, var_to_idx[var])
350+
end
351+
elseif v isa AbstractArray
352+
for val in v
353+
vars!(buffer, val)
354+
end
355+
for var in buffer
356+
haskey(var_to_idx, var) || continue
357+
add_edge!(graph, kidx, var_to_idx[var])
358+
end
359+
end
360+
empty!(buffer)
361+
end
362+
363+
# detect at most 100 cycles involving at most `length(varmap)` vertices
364+
cycles = Graphs.simplecycles_limited_length(graph, max_cycle_length, max_cycles)
365+
# only count those which contain variables in `vars`
366+
filter!(Base.Fix1(any, <=(length(vars))), cycles)
367+
368+
map(cycles) do cycle
369+
map(Base.Fix1(getindex, allvars), cycle)
370+
end
371+
end
372+
301373
"""
302374
$(TYPEDSIGNATURES)
303375
304376
Performs symbolic substitution on the values in `varmap` for the keys in `vars`, using
305377
`varmap` itself as the set of substitution rules. If an entry in `vars` is not a key
306378
in `varmap`, it is ignored.
307379
"""
308-
function evaluate_varmap!(varmap::AbstractDict, vars)
380+
function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
309381
for k in vars
310382
haskey(varmap, k) || continue
311-
varmap[k] = fixpoint_sub(varmap[k], varmap)
383+
varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit)
312384
end
313385
end
314386

@@ -407,6 +479,14 @@ Keyword arguments:
407479
length of `u0` vector for consistency. If `false`, do not check with equations. This is
408480
forwarded to `check_eqs_u0`
409481
- `symbolic_u0` allows the returned `u0` to be an array of symbolics.
482+
- `warn_cyclic_dependency`: Whether to emit a warning listing out cycles in initial
483+
conditions provided for unknowns and parameters.
484+
- `circular_dependency_max_cycle_length`: Maximum length of cycle to check for.
485+
Only applicable if `warn_cyclic_dependency == true`.
486+
- `circular_dependency_max_cycles`: Maximum number of cycles to check for.
487+
Only applicable if `warn_cyclic_dependency == true`.
488+
- `substitution_limit`: The number times to substitute initial conditions into each
489+
other to attempt to arrive at a numeric value.
410490
411491
All other keyword arguments are passed as-is to `constructor`.
412492
"""
@@ -416,7 +496,11 @@ function process_SciMLProblem(
416496
warn_initialize_determined = true, initialization_eqs = [],
417497
eval_expression = false, eval_module = @__MODULE__, fully_determined = false,
418498
check_initialization_units = false, tofloat = true, use_union = false,
419-
u0_constructor = identity, du0map = nothing, check_length = true, symbolic_u0 = false, kwargs...)
499+
u0_constructor = identity, du0map = nothing, check_length = true,
500+
symbolic_u0 = false, warn_cyclic_dependency = false,
501+
circular_dependency_max_cycle_length = length(all_symbols(sys)),
502+
circular_dependency_max_cycles = 10,
503+
substitution_limit = length(observed(sys)), kwargs...)
420504
dvs = unknowns(sys)
421505
ps = parameters(sys)
422506
iv = has_iv(sys) ? get_iv(sys) : nothing
@@ -466,7 +550,8 @@ function process_SciMLProblem(
466550
initializeprob = ModelingToolkit.InitializationProblem(
467551
sys, t, u0map, pmap; guesses, warn_initialize_determined,
468552
initialization_eqs, eval_expression, eval_module, fully_determined,
469-
check_units = check_initialization_units)
553+
warn_cyclic_dependency, check_units = check_initialization_units,
554+
circular_dependency_max_cycle_length, circular_dependency_max_cycles)
470555
initializeprobmap = getu(initializeprob, unknowns(sys))
471556

472557
punknowns = [p
@@ -503,7 +588,20 @@ function process_SciMLProblem(
503588
add_observed!(sys, op)
504589
add_parameter_dependencies!(sys, op)
505590

506-
evaluate_varmap!(op, dvs)
591+
if warn_cyclic_dependency
592+
cycles = check_substitution_cycles(
593+
op, dvs; max_cycle_length = circular_dependency_max_cycle_length,
594+
max_cycles = circular_dependency_max_cycles)
595+
if !isempty(cycles)
596+
buffer = IOBuffer()
597+
for cycle in cycles
598+
println(buffer, cycle)
599+
end
600+
msg = String(take!(buffer))
601+
@warn "Cycles in unknowns:\n$msg"
602+
end
603+
end
604+
evaluate_varmap!(op, dvs; limit = substitution_limit)
507605

508606
u0 = better_varmap_to_vars(
509607
op, dvs; tofloat = true, use_union = false,
@@ -515,7 +613,20 @@ function process_SciMLProblem(
515613

516614
check_eqs_u0(eqs, dvs, u0; check_length, kwargs...)
517615

518-
evaluate_varmap!(op, ps)
616+
if warn_cyclic_dependency
617+
cycles = check_substitution_cycles(
618+
op, ps; max_cycle_length = circular_dependency_max_cycle_length,
619+
max_cycles = circular_dependency_max_cycles)
620+
if !isempty(cycles)
621+
buffer = IOBuffer()
622+
for cycle in cycles
623+
println(buffer, cycle)
624+
end
625+
msg = String(take!(buffer))
626+
@warn "Cycles in parameters:\n$msg"
627+
end
628+
end
629+
evaluate_varmap!(op, ps; limit = substitution_limit)
519630
if is_split(sys)
520631
p = MTKParameters(sys, op)
521632
else

test/initial_values.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,28 @@ end
149149
pmap = [c1 => 5.0, c2 => 1.0, c3 => 1.2]
150150
oprob = ODEProblem(osys, u0map, (0.0, 10.0), pmap)
151151
end
152+
153+
@testset "Cyclic dependency checking and substitution limits" begin
154+
@variables x(t) y(t)
155+
@mtkbuild sys = ODESystem(
156+
[D(x) ~ x, D(y) ~ y], t; initialization_eqs = [x ~ 2y + 3, y ~ 2x],
157+
guesses = [x => 2y, y => 2x])
158+
@test_warn ["Cycle", "unknowns", "x", "y"] try
159+
ODEProblem(sys, [], (0.0, 1.0), warn_cyclic_dependency = true)
160+
catch
161+
end
162+
@test_throws ModelingToolkit.UnexpectedSymbolicValueInVarmap ODEProblem(
163+
sys, [], (0.0, 1.0))
164+
165+
@parameters p q
166+
@mtkbuild sys = ODESystem(
167+
[D(x) ~ x * p, D(y) ~ y * q], t; guesses = [p => 1.0, q => 2.0])
168+
# "unknowns" because they are initialization unknowns
169+
@test_warn ["Cycle", "unknowns", "p", "q"] try
170+
ODEProblem(sys, [x => 1, y => 2], (0.0, 1.0),
171+
[p => 2q, q => 3p]; warn_cyclic_dependency = true)
172+
catch
173+
end
174+
@test_throws ModelingToolkit.UnexpectedSymbolicValueInVarmap ODEProblem(
175+
sys, [x => 1, y => 2], (0.0, 1.0), [p => 2q, q => 3p])
176+
end

0 commit comments

Comments
 (0)