diff --git a/src/remake.jl b/src/remake.jl index 0ffd91003..43812df11 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing) return default end +""" + $(TYPEDSIGNATURES) + +Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in +`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default. +""" +function detect_cycles(indp, varmap, syms) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && + (sc = symbolic_container(indp)) != indp + return detect_cycles(sc, varmap, syms) + else + return false + end +end + anydict(d::Dict{Any, Any}) = d anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() @@ -560,14 +575,24 @@ function _updated_u0_p_internal( end function fill_u0(prob, u0; defs = nothing, use_defaults = false) - vsyms = variable_symbols(prob) - idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms) + fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob), + index_function = variable_index) +end + +function fill_p(prob, p; defs = nothing, use_defaults = false) + fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob), + index_function = parameter_index) +end + +function fill_vars( + prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function) + idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms) sym_to_idx = anydict() idx_to_sym = anydict() idx_to_val = anydict() - for (k, v) in u0 + for (k, v) in varmap v === nothing && continue - idx = variable_index(prob, k) + idx = index_function(prob, k) idx === nothing && continue if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() idx = (idx,) @@ -582,9 +607,9 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) idx_to_val[ii] = vv end end - for sym in vsyms + for sym in allsyms haskey(sym_to_idx, sym) && continue - idx = variable_index(prob, sym) + idx = index_function(prob, sym) haskey(idx_to_val, idx) && continue sym_to_idx[sym] = idx idx_to_sym[idx] = sym @@ -593,72 +618,40 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) (symbolic_type(defval) != NotSymbolic() || use_defaults) defval else - getu(prob, sym)(prob) + getsym(prob, sym)(prob) end end newvals = anydict() for (idx, val) in idx_to_val newvals[idx_to_sym[idx]] = val end - for (k, v) in u0 + for (k, v) in varmap haskey(sym_to_idx, k) && continue newvals[k] = v end return newvals end -function fill_p(prob, p; defs = nothing, use_defaults = false) - psyms = parameter_symbols(prob) - idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms) - sym_to_idx = anydict() - idx_to_sym = anydict() - idx_to_val = anydict() - for (k, v) in p - v === nothing && continue - idx = parameter_index(prob, k) - idx === nothing && continue - if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() - idx = (idx,) - k = (k,) - v = (v,) - end - for (kk, vv, ii) in zip(k, v, idx) - sym_to_idx[kk] = ii - kk = idx_to_psym[ii] - sym_to_idx[kk] = ii - idx_to_sym[ii] = kk - idx_to_val[ii] = vv - end - end - for sym in psyms - haskey(sym_to_idx, sym) && continue - idx = parameter_index(prob, sym) - haskey(idx_to_val, idx) && continue - sym_to_idx[sym] = idx - idx_to_sym[idx] = sym - idx_to_val[idx] = if defs !== nothing && - (defval = varmap_get(defs, sym)) !== nothing && - (symbolic_type(defval) != NotSymbolic() || use_defaults) - defval - else - getp(prob, sym)(prob) - end - end - newvals = anydict() - for (idx, val) in idx_to_val - newvals[idx_to_sym[idx]] = val - end - for (k, v) in p - haskey(sym_to_idx, k) && continue - newvals[k] = v +struct CyclicDependencyError <: Exception + varmap::Dict{Any, Any} + vars::Any +end + +function Base.showerror(io::IO, err::CyclicDependencyError) + println(io, "Detected cyclic dependency in initial values:") + for (k, v) in err.varmap + println(io, k, " => ", "v") end - return newvals + println(io, "While trying to solve for variables: ", err.vars) end function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p + if detect_cycles(prob, u0, variable_symbols(prob)) + throw(CyclicDependencyError(u0, variable_symbols(prob))) + end for (k, v) in u0 u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0) end @@ -680,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) + if detect_cycles(prob, p, parameter_symbols(prob)) + throw(CyclicDependencyError(p, parameter_symbols(prob))) + end for (k, v) in p p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p) end @@ -707,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) end varmap = merge(u0, p) + allsyms = [variable_symbols(prob); parameter_symbols(prob)] + if detect_cycles(prob, varmap, allsyms) + throw(CyclicDependencyError(varmap, allsyms)) + end if is_time_dependent(prob) varmap[only(independent_variable_symbols(prob))] = t0 end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 7cca426f5..41a2681f3 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -209,7 +209,7 @@ indices that can be plotted as continuous variables. This is useful for systems that store auxiliary variables in the state vector which are not meant to be used for plotting. """ -plottable_indices(x:: AbstractArray) = 1:length(x) +plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 @recipe function f(sol::AbstractTimeseriesSolution; diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index d0df44658..75fcec09d 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -248,3 +248,29 @@ end prob2 = remake(prob; u0 = [x => t + 3.0]) @test prob2[x] ≈ 3.0 end + +@static if length(methods(SciMLBase.detect_cycles)) == 1 + function SciMLBase.detect_cycles( + ::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars) + for sym in vars + if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) != + NotSymbolic() + return true + end + end + return false + end +end + +@testset "Cycle detection" begin + @variables x(t) y(t) + @parameters p q + @mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t) + + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), [p => 1.0, q => 1.0]) + @test_throws SciMLBase.CyclicDependencyError remake( + prob; u0 = [x => 2y + 3, y => 2x + 1]) + @test_throws SciMLBase.CyclicDependencyError remake(prob; p = [p => 2q + 1, q => p + 3]) + @test_throws SciMLBase.CyclicDependencyError remake( + prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3]) +end