Skip to content

Commit 0b7d8fa

Browse files
feat: add cycle detection in initial conditions
1 parent e4e7c8b commit 0b7d8fa

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

src/remake.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing)
509509
return default
510510
end
511511

512+
"""
513+
$(TYPEDSIGNATURES)
514+
515+
Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in
516+
`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default.
517+
"""
518+
function detect_cycles(indp, varmap, syms)
519+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
520+
(sc = symbolic_container(indp)) != indp
521+
return detect_cycles(sc, varmap, syms)
522+
else
523+
return false
524+
end
525+
end
526+
512527
anydict(d::Dict{Any, Any}) = d
513528
anydict(d) = Dict{Any, Any}(d)
514529
anydict() = Dict{Any, Any}()
@@ -571,7 +586,7 @@ end
571586

572587
function fill_vars(
573588
prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function)
574-
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in allsyms)
589+
idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms)
575590
sym_to_idx = anydict()
576591
idx_to_sym = anydict()
577592
idx_to_val = anydict()
@@ -617,10 +632,26 @@ function fill_vars(
617632
return newvals
618633
end
619634

635+
struct CyclicDependencyError <: Exception
636+
varmap::Dict{Any, Any}
637+
vars::Any
638+
end
639+
640+
function Base.showerror(io::IO, err::CyclicDependencyError)
641+
println(io, "Detected cyclic dependency in initial values:")
642+
for (k, v) in err.varmap
643+
println(io, k, " => ", "v")
644+
end
645+
println(io, "While trying to solve for variables: ", err.vars)
646+
end
647+
620648
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
621649
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
622650
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
623651

652+
if detect_cycles(prob, u0, variable_symbols(prob))
653+
throw(CyclicDependencyError(u0, variable_symbols(prob)))
654+
end
624655
for (k, v) in u0
625656
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
626657
end
@@ -642,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
642673
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
643674
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
644675

676+
if detect_cycles(prob, p, parameter_symbols(prob))
677+
throw(CyclicDependencyError(p, parameter_symbols(prob)))
678+
end
645679
for (k, v) in p
646680
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
647681
end
@@ -669,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
669703
end
670704

671705
varmap = merge(u0, p)
706+
allsyms = [variable_symbols(prob); parameter_symbols(prob)]
707+
if detect_cycles(prob, varmap, allsyms)
708+
throw(CyclicDependencyError(varmap, allsyms))
709+
end
672710
if is_time_dependent(prob)
673711
varmap[only(independent_variable_symbols(prob))] = t0
674712
end

0 commit comments

Comments
 (0)