Skip to content

Commit d366481

Browse files
Merge pull request #852 from AayushSabharwal/as/remake-debugging
feat: add cycle detection in initial conditions
2 parents 0453006 + b3949ed commit d366481

File tree

3 files changed

+79
-53
lines changed

3 files changed

+79
-53
lines changed

src/remake.jl

Lines changed: 52 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing)
509509
return default
510510
end
511511

512+
"""
513+
$(TYPEDSIGNATURES)
514+
515+
Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in
516+
`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default.
517+
"""
518+
function detect_cycles(indp, varmap, syms)
519+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
520+
(sc = symbolic_container(indp)) != indp
521+
return detect_cycles(sc, varmap, syms)
522+
else
523+
return false
524+
end
525+
end
526+
512527
anydict(d::Dict{Any, Any}) = d
513528
anydict(d) = Dict{Any, Any}(d)
514529
anydict() = Dict{Any, Any}()
@@ -560,14 +575,24 @@ function _updated_u0_p_internal(
560575
end
561576

562577
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
563-
vsyms = variable_symbols(prob)
564-
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
578+
fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob),
579+
index_function = variable_index)
580+
end
581+
582+
function fill_p(prob, p; defs = nothing, use_defaults = false)
583+
fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob),
584+
index_function = parameter_index)
585+
end
586+
587+
function fill_vars(
588+
prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function)
589+
idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms)
565590
sym_to_idx = anydict()
566591
idx_to_sym = anydict()
567592
idx_to_val = anydict()
568-
for (k, v) in u0
593+
for (k, v) in varmap
569594
v === nothing && continue
570-
idx = variable_index(prob, k)
595+
idx = index_function(prob, k)
571596
idx === nothing && continue
572597
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
573598
idx = (idx,)
@@ -582,9 +607,9 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
582607
idx_to_val[ii] = vv
583608
end
584609
end
585-
for sym in vsyms
610+
for sym in allsyms
586611
haskey(sym_to_idx, sym) && continue
587-
idx = variable_index(prob, sym)
612+
idx = index_function(prob, sym)
588613
haskey(idx_to_val, idx) && continue
589614
sym_to_idx[sym] = idx
590615
idx_to_sym[idx] = sym
@@ -600,65 +625,33 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
600625
for (idx, val) in idx_to_val
601626
newvals[idx_to_sym[idx]] = val
602627
end
603-
for (k, v) in u0
628+
for (k, v) in varmap
604629
haskey(sym_to_idx, k) && continue
605630
newvals[k] = v
606631
end
607632
return newvals
608633
end
609634

610-
function fill_p(prob, p; defs = nothing, use_defaults = false)
611-
psyms = parameter_symbols(prob)
612-
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
613-
sym_to_idx = anydict()
614-
idx_to_sym = anydict()
615-
idx_to_val = anydict()
616-
for (k, v) in p
617-
v === nothing && continue
618-
idx = parameter_index(prob, k)
619-
idx === nothing && continue
620-
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
621-
idx = (idx,)
622-
k = (k,)
623-
v = (v,)
624-
end
625-
for (kk, vv, ii) in zip(k, v, idx)
626-
sym_to_idx[kk] = ii
627-
kk = idx_to_psym[ii]
628-
sym_to_idx[kk] = ii
629-
idx_to_sym[ii] = kk
630-
idx_to_val[ii] = vv
631-
end
632-
end
633-
for sym in psyms
634-
haskey(sym_to_idx, sym) && continue
635-
idx = parameter_index(prob, sym)
636-
haskey(idx_to_val, idx) && continue
637-
sym_to_idx[sym] = idx
638-
idx_to_sym[idx] = sym
639-
idx_to_val[idx] = if defs !== nothing &&
640-
(defval = varmap_get(defs, sym)) !== nothing &&
641-
(symbolic_type(defval) != NotSymbolic() || use_defaults)
642-
defval
643-
else
644-
getp(prob, sym)(prob)
645-
end
646-
end
647-
newvals = anydict()
648-
for (idx, val) in idx_to_val
649-
newvals[idx_to_sym[idx]] = val
650-
end
651-
for (k, v) in p
652-
haskey(sym_to_idx, k) && continue
653-
newvals[k] = v
635+
struct CyclicDependencyError <: Exception
636+
varmap::Dict{Any, Any}
637+
vars::Any
638+
end
639+
640+
function Base.showerror(io::IO, err::CyclicDependencyError)
641+
println(io, "Detected cyclic dependency in initial values:")
642+
for (k, v) in err.varmap
643+
println(io, k, " => ", "v")
654644
end
655-
return newvals
645+
println(io, "While trying to solve for variables: ", err.vars)
656646
end
657647

658648
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
659649
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
660650
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
661651

652+
if detect_cycles(prob, u0, variable_symbols(prob))
653+
throw(CyclicDependencyError(u0, variable_symbols(prob)))
654+
end
662655
for (k, v) in u0
663656
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
664657
end
@@ -680,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
680673
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
681674
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
682675

676+
if detect_cycles(prob, p, parameter_symbols(prob))
677+
throw(CyclicDependencyError(p, parameter_symbols(prob)))
678+
end
683679
for (k, v) in p
684680
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
685681
end
@@ -707,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
707703
end
708704

709705
varmap = merge(u0, p)
706+
allsyms = [variable_symbols(prob); parameter_symbols(prob)]
707+
if detect_cycles(prob, varmap, allsyms)
708+
throw(CyclicDependencyError(varmap, allsyms))
709+
end
710710
if is_time_dependent(prob)
711711
varmap[only(independent_variable_symbols(prob))] = t0
712712
end

src/solutions/solution_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ indices that can be plotted as continuous variables. This is useful for systems
209209
that store auxiliary variables in the state vector which are not meant to be
210210
used for plotting.
211211
"""
212-
plottable_indices(x:: AbstractArray) = 1:length(x)
212+
plottable_indices(x::AbstractArray) = 1:length(x)
213213
plottable_indices(x::Number) = 1
214214

215215
@recipe function f(sol::AbstractTimeseriesSolution;

test/downstream/modelingtoolkit_remake.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,29 @@ end
248248
prob2 = remake(prob; u0 = [x => t + 3.0])
249249
@test prob2[x] 3.0
250250
end
251+
252+
@static if length(methods(SciMLBase.detect_cycles)) == 1
253+
function SciMLBase.detect_cycles(
254+
::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars)
255+
for sym in vars
256+
if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) !=
257+
NotSymbolic()
258+
return true
259+
end
260+
end
261+
return false
262+
end
263+
end
264+
265+
@testset "Cycle detection" begin
266+
@variables x(t) y(t)
267+
@parameters p q
268+
@mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t)
269+
270+
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), [p => 1.0, q => 1.0])
271+
@test_throws SciMLBase.CyclicDependencyError remake(
272+
prob; u0 = [x => 2y + 3, y => 2x + 1])
273+
@test_throws SciMLBase.CyclicDependencyError remake(prob; p = [p => 2q + 1, q => p + 3])
274+
@test_throws SciMLBase.CyclicDependencyError remake(
275+
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
276+
end

0 commit comments

Comments
 (0)