Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 53 additions & 53 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing)
return default
end

"""
$(TYPEDSIGNATURES)

Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in
`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default.
"""
function detect_cycles(indp, varmap, syms)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) != indp
return detect_cycles(sc, varmap, syms)
else
return false
end
end

anydict(d::Dict{Any, Any}) = d
anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()
Expand Down Expand Up @@ -560,14 +575,24 @@ function _updated_u0_p_internal(
end

function fill_u0(prob, u0; defs = nothing, use_defaults = false)
vsyms = variable_symbols(prob)
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob),
index_function = variable_index)
end

function fill_p(prob, p; defs = nothing, use_defaults = false)
fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob),
index_function = parameter_index)
end

function fill_vars(
prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function)
idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
for (k, v) in u0
for (k, v) in varmap
v === nothing && continue
idx = variable_index(prob, k)
idx = index_function(prob, k)
idx === nothing && continue
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
Expand All @@ -582,9 +607,9 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
idx_to_val[ii] = vv
end
end
for sym in vsyms
for sym in allsyms
haskey(sym_to_idx, sym) && continue
idx = variable_index(prob, sym)
idx = index_function(prob, sym)
haskey(idx_to_val, idx) && continue
sym_to_idx[sym] = idx
idx_to_sym[idx] = sym
Expand All @@ -593,72 +618,40 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
(symbolic_type(defval) != NotSymbolic() || use_defaults)
defval
else
getu(prob, sym)(prob)
getsym(prob, sym)(prob)
end
end
newvals = anydict()
for (idx, val) in idx_to_val
newvals[idx_to_sym[idx]] = val
end
for (k, v) in u0
for (k, v) in varmap
haskey(sym_to_idx, k) && continue
newvals[k] = v
end
return newvals
end

function fill_p(prob, p; defs = nothing, use_defaults = false)
psyms = parameter_symbols(prob)
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
for (k, v) in p
v === nothing && continue
idx = parameter_index(prob, k)
idx === nothing && continue
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
kk = idx_to_psym[ii]
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in psyms
haskey(sym_to_idx, sym) && continue
idx = parameter_index(prob, sym)
haskey(idx_to_val, idx) && continue
sym_to_idx[sym] = idx
idx_to_sym[idx] = sym
idx_to_val[idx] = if defs !== nothing &&
(defval = varmap_get(defs, sym)) !== nothing &&
(symbolic_type(defval) != NotSymbolic() || use_defaults)
defval
else
getp(prob, sym)(prob)
end
end
newvals = anydict()
for (idx, val) in idx_to_val
newvals[idx_to_sym[idx]] = val
end
for (k, v) in p
haskey(sym_to_idx, k) && continue
newvals[k] = v
struct CyclicDependencyError <: Exception
varmap::Dict{Any, Any}
vars::Any
end

function Base.showerror(io::IO, err::CyclicDependencyError)
println(io, "Detected cyclic dependency in initial values:")
for (k, v) in err.varmap
println(io, k, " => ", "v")
end
return newvals
println(io, "While trying to solve for variables: ", err.vars)
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

if detect_cycles(prob, u0, variable_symbols(prob))
throw(CyclicDependencyError(u0, variable_symbols(prob)))
end
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
end
Expand All @@ -680,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

if detect_cycles(prob, p, parameter_symbols(prob))
throw(CyclicDependencyError(p, parameter_symbols(prob)))
end
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
end
Expand Down Expand Up @@ -707,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
end

varmap = merge(u0, p)
allsyms = [variable_symbols(prob); parameter_symbols(prob)]
if detect_cycles(prob, varmap, allsyms)
throw(CyclicDependencyError(varmap, allsyms))
end
if is_time_dependent(prob)
varmap[only(independent_variable_symbols(prob))] = t0
end
Expand Down
2 changes: 1 addition & 1 deletion src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ indices that can be plotted as continuous variables. This is useful for systems
that store auxiliary variables in the state vector which are not meant to be
used for plotting.
"""
plottable_indices(x:: AbstractArray) = 1:length(x)
plottable_indices(x::AbstractArray) = 1:length(x)
plottable_indices(x::Number) = 1

@recipe function f(sol::AbstractTimeseriesSolution;
Expand Down
26 changes: 26 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,29 @@ end
prob2 = remake(prob; u0 = [x => t + 3.0])
@test prob2[x] ≈ 3.0
end

@static if length(methods(SciMLBase.detect_cycles)) == 1
function SciMLBase.detect_cycles(
::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars)
for sym in vars
if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) !=
NotSymbolic()
return true
end
end
return false
end
end

@testset "Cycle detection" begin
@variables x(t) y(t)
@parameters p q
@mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t)

prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), [p => 1.0, q => 1.0])
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + 3, y => 2x + 1])
@test_throws SciMLBase.CyclicDependencyError remake(prob; p = [p => 2q + 1, q => p + 3])
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
end
Loading