Skip to content

Commit babd239

Browse files
Merge pull request #837 from AayushSabharwal/as/remake-fix
fix: fix remake with u0 dependent on `Symbol` parameter
2 parents dd0da91 + 95c7fff commit babd239

File tree

5 files changed

+63
-9
lines changed

5 files changed

+63
-9
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
102102
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
103103
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
104104
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
105-
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
106-
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
107105
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
108106
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
109107
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -116,4 +114,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
116114
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
117115

118116
[targets]
119-
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
117+
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]

src/remake.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ end
567567

568568
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
569569
vsyms = variable_symbols(prob)
570+
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
570571
sym_to_idx = anydict()
571572
idx_to_sym = anydict()
572573
idx_to_val = anydict()
@@ -580,6 +581,8 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
580581
v = (v,)
581582
end
582583
for (kk, vv, ii) in zip(k, v, idx)
584+
sym_to_idx[kk] = ii
585+
kk = idx_to_vsym[ii]
583586
sym_to_idx[kk] = ii
584587
idx_to_sym[ii] = kk
585588
idx_to_val[ii] = vv
@@ -612,6 +615,7 @@ end
612615

613616
function fill_p(prob, p; defs = nothing, use_defaults = false)
614617
psyms = parameter_symbols(prob)
618+
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
615619
sym_to_idx = anydict()
616620
idx_to_sym = anydict()
617621
idx_to_val = anydict()
@@ -625,6 +629,8 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
625629
v = (v,)
626630
end
627631
for (kk, vv, ii) in zip(k, v, idx)
632+
sym_to_idx[kk] = ii
633+
kk = idx_to_psym[ii]
628634
sym_to_idx[kk] = ii
629635
idx_to_sym[ii] = kk
630636
idx_to_val[ii] = vv
@@ -707,6 +713,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
707713
end
708714

709715
varmap = merge(u0, p)
716+
if is_time_dependent(prob)
717+
varmap[only(independent_variable_symbols(prob))] = t0
718+
end
710719
for (k, v) in u0
711720
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
712721
end

test/downstream/modelingtoolkit_remake.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,22 @@ end
229229
prob2 = remake(prob; u0 = [y => 3.0], p = Dict())
230230
@test prob2.ps[p] 4.0
231231
end
232+
233+
@testset "u0 dependent on parameter given as Symbol" begin
234+
@variables x(t)
235+
@parameters p
236+
@mtkbuild sys = ODESystem([D(x) ~ x * p], t)
237+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
238+
@test prob.ps[p] 1.0
239+
prob2 = remake(prob; u0 = [x => p], p = [:p => 2.0])
240+
@test prob2[x] 2.0
241+
end
242+
243+
@testset "remake dependent on indepvar" begin
244+
@variables x(t)
245+
@parameters p
246+
@mtkbuild sys = ODESystem([D(x) ~ x * p], t)
247+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
248+
prob2 = remake(prob; u0 = [x => t + 3.0])
249+
@test prob2[x] 3.0
250+
end

test/downstream/solution_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ end
293293
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
294294
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
295295
ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r])
296-
@test SciMLBase.get_saved_state_idxs(ss) == [xidx]
296+
@test SciMLBase.get_saved_state_idxs(ss) == [variable_index(sys, x)]
297297
sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys)
298298
xidx = variable_index(sys, x)
299299
qidx = parameter_index(sys, q)

test/remake_tests.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,26 +297,54 @@ a = Remake_Test1(p = 1)
297297
@test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1)
298298

299299
@testset "fill_u0 and fill_p ignore identical variables with different names" begin
300-
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
301-
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
300+
struct SCWrapper{S}
301+
sc::S
302+
end
303+
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sc
304+
function SymbolicIndexingInterface.is_variable(s::SCWrapper, i::Symbol)
305+
if i == :x2
306+
return is_variable(s.sc, :x)
307+
end
308+
is_variable(s.sc, i)
309+
end
310+
function SymbolicIndexingInterface.variable_index(s::SCWrapper, i::Symbol)
311+
if i == :x2
312+
return variable_index(s.sc, :x)
313+
end
314+
variable_index(s.sc, i)
315+
end
316+
function SymbolicIndexingInterface.is_parameter(s::SCWrapper, i::Symbol)
317+
if i == :a2
318+
return is_parameter(s.sc, :a)
319+
end
320+
is_parameter(s.sc, i)
321+
end
322+
function SymbolicIndexingInterface.parameter_index(s::SCWrapper, i::Symbol)
323+
if i == :a2
324+
return parameter_index(s.sc, :a)
325+
end
326+
parameter_index(s.sc, i)
327+
end
328+
sys = SCWrapper(SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2),
329+
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4)))
302330
function foo(du, u, p, t)
303331
du .= u .* p
304332
end
305333
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
306334
u0 = Dict(:x2 => 2)
307335
newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys))
308336
@test length(newu0) == 2
309-
@test get(newu0, :x2, 0) == 2
337+
@test get(newu0, :x, 0) == 2
310338
@test get(newu0, :y, 0) == 2.5
311339
p = Dict(:a2 => 3)
312340
newp = SciMLBase.fill_p(prob, p; defs = default_values(sys))
313341
@test length(newp) == 2
314-
@test get(newp, :a2, 0) == 3
342+
@test get(newp, :a, 0) == 3
315343
@test get(newp, :b, 0) == 4.5
316344
end
317345

318346
@testset "value of `nothing` is ignored" begin
319-
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
347+
sys = SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2),
320348
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
321349
function foo(du, u, p, t)
322350
du .= u .* p

0 commit comments

Comments
 (0)