Skip to content

Commit bf913e7

Browse files
Merge pull request #832 from AayushSabharwal/as/remake-fix
fix: fix `remake` for parameters dependent on observed variables
2 parents d3611ca + 3a25351 commit bf913e7

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

src/remake.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ function varmap_get(varmap, var, default = nothing)
515515
return default
516516
end
517517

518+
anydict(d::Dict{Any, Any}) = d
518519
anydict(d) = Dict{Any, Any}(d)
519520
anydict() = Dict{Any, Any}()
520521

@@ -658,8 +659,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
658659
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
659660
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
660661

661-
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
662-
for (k, v) in u0)
662+
for (k, v) in u0
663+
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
664+
end
663665

664666
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
665667
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
@@ -668,17 +670,19 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
668670
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
669671
# used, since any state symbols in the expression were substituted out earlier.
670672
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
671-
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
672-
for (k, v) in u0)
673+
for (k, v) in u0
674+
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
675+
end
673676
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
674677
end
675678

676679
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
677680
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
678681
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
679682

680-
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
681-
for (k, v) in p)
683+
for (k, v) in p
684+
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
685+
end
682686

683687
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
684688
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
@@ -687,8 +691,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
687691
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
688692
# used, since any parameter symbols in the expression were substituted out earlier.
689693
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
690-
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
691-
for (k, v) in p)
694+
for (k, v) in p
695+
p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
696+
end
692697
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
693698
end
694699

@@ -700,20 +705,14 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
700705
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
701706
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
702707
end
703-
if !isu0dep
704-
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
705-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
706-
end
707-
if !ispdep
708-
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
709-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
710-
end
711708

712709
varmap = merge(u0, p)
713-
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
714-
for (k, v) in u0)
715-
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
716-
for (k, v) in p)
710+
for (k, v) in u0
711+
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
712+
end
713+
for (k, v) in p
714+
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
715+
end
717716
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
718717
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
719718
end

test/downstream/modelingtoolkit_remake.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,13 @@ end
219219
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
220220
@test newoprob[V] == [1.5, 2.5]
221221
end
222+
223+
@testset "remake with parameter dependent on observed" begin
224+
@variables x(t) y(t)
225+
@parameters p = x + y
226+
@mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t)
227+
prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0))
228+
@test prob.ps[p] 3.0
229+
prob2 = remake(prob; u0 = [y => 3.0], p = Dict())
230+
@test prob2.ps[p] 4.0
231+
end

0 commit comments

Comments
 (0)