diff --git a/src/remake.jl b/src/remake.jl index 15a26b338..b84625fde 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -515,6 +515,7 @@ function varmap_get(varmap, var, default = nothing) return default end +anydict(d::Dict{Any, Any}) = d anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() @@ -658,8 +659,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p - u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0) - for (k, v) in u0) + for (k, v) in u0 + u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0) + end isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p @@ -668,8 +670,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) # This is sort of an implicit dependency on MTK. The values of `u` won't actually be # used, since any state symbols in the expression were substituted out earlier. temp_state = ProblemState(; u = state_values(prob), p = p, t = t0) - u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) - for (k, v) in u0) + for (k, v) in u0 + u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) + end return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p end @@ -677,8 +680,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) - p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p) - for (k, v) in p) + for (k, v) in p + p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p) + end isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) @@ -687,8 +691,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) # this is sort of an implicit dependency on MTK. The values of `p` won't actually be # used, since any parameter symbols in the expression were substituted out earlier. temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0) - p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) - for (k, v) in p) + for (k, v) in p + p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) + end return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end @@ -700,20 +705,14 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end - if !isu0dep - u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0)) - return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0) - end - if !ispdep - p = remake_buffer(prob, parameter_values(prob), keys(p), values(p)) - return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0) - end varmap = merge(u0, p) - u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) - for (k, v) in u0) - p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) - for (k, v) in p) + for (k, v) in u0 + u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) + end + for (k, v) in p + p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) + end return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index eec0b6baf..38d379ddc 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -219,3 +219,13 @@ end @test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0] @test newoprob[V] == [1.5, 2.5] end + +@testset "remake with parameter dependent on observed" begin + @variables x(t) y(t) + @parameters p = x + y + @mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t) + prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0)) + @test prob.ps[p] ≈ 3.0 + prob2 = remake(prob; u0 = [y => 3.0], p = Dict()) + @test prob2.ps[p] ≈ 4.0 +end