Skip to content

Commit a56253e

Browse files
fix: fix remake with u0 dependent on Symbol parameter
1 parent 0f8ec1f commit a56253e

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

src/remake.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ end
567567

568568
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
569569
vsyms = variable_symbols(prob)
570+
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
570571
sym_to_idx = anydict()
571572
idx_to_sym = anydict()
572573
idx_to_val = anydict()
@@ -580,6 +581,8 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
580581
v = (v,)
581582
end
582583
for (kk, vv, ii) in zip(k, v, idx)
584+
sym_to_idx[kk] = ii
585+
kk = idx_to_vsym[ii]
583586
sym_to_idx[kk] = ii
584587
idx_to_sym[ii] = kk
585588
idx_to_val[ii] = vv
@@ -612,6 +615,7 @@ end
612615

613616
function fill_p(prob, p; defs = nothing, use_defaults = false)
614617
psyms = parameter_symbols(prob)
618+
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
615619
sym_to_idx = anydict()
616620
idx_to_sym = anydict()
617621
idx_to_val = anydict()
@@ -625,6 +629,8 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
625629
v = (v,)
626630
end
627631
for (kk, vv, ii) in zip(k, v, idx)
632+
sym_to_idx[kk] = ii
633+
kk = idx_to_psym[ii]
628634
sym_to_idx[kk] = ii
629635
idx_to_sym[ii] = kk
630636
idx_to_val[ii] = vv

test/downstream/modelingtoolkit_remake.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,13 @@ end
229229
prob2 = remake(prob; u0 = [y => 3.0], p = Dict())
230230
@test prob2.ps[p] 4.0
231231
end
232+
233+
@testset "u0 dependent on parameter given as Symbol" begin
234+
@variables x(t)
235+
@parameters p
236+
@mtkbuild sys = ODESystem([D(x) ~ x * p], t)
237+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
238+
@test prob.ps[p] 1.0
239+
prob2 = remake(prob; u0 = [x => p], p = [:p => 2.0])
240+
@test prob2[x] 2.0
241+
end

test/remake_tests.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,26 +297,54 @@ a = Remake_Test1(p = 1)
297297
@test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1)
298298

299299
@testset "fill_u0 and fill_p ignore identical variables with different names" begin
300-
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
301-
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
300+
struct SCWrapper{S}
301+
sc::S
302+
end
303+
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sc
304+
function SymbolicIndexingInterface.is_variable(s::SCWrapper, i::Symbol)
305+
if i == :x2
306+
return is_variable(s.sc, :x)
307+
end
308+
is_variable(s.sc, i)
309+
end
310+
function SymbolicIndexingInterface.variable_index(s::SCWrapper, i::Symbol)
311+
if i == :x2
312+
return variable_index(s.sc, :x)
313+
end
314+
variable_index(s.sc, i)
315+
end
316+
function SymbolicIndexingInterface.is_parameter(s::SCWrapper, i::Symbol)
317+
if i == :a2
318+
return is_parameter(s.sc, :a)
319+
end
320+
is_parameter(s.sc, i)
321+
end
322+
function SymbolicIndexingInterface.parameter_index(s::SCWrapper, i::Symbol)
323+
if i == :a2
324+
return parameter_index(s.sc, :a)
325+
end
326+
parameter_index(s.sc, i)
327+
end
328+
sys = SCWrapper(SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2),
329+
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4)))
302330
function foo(du, u, p, t)
303331
du .= u .* p
304332
end
305333
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
306334
u0 = Dict(:x2 => 2)
307335
newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys))
308336
@test length(newu0) == 2
309-
@test get(newu0, :x2, 0) == 2
337+
@test get(newu0, :x, 0) == 2
310338
@test get(newu0, :y, 0) == 2.5
311339
p = Dict(:a2 => 3)
312340
newp = SciMLBase.fill_p(prob, p; defs = default_values(sys))
313341
@test length(newp) == 2
314-
@test get(newp, :a2, 0) == 3
342+
@test get(newp, :a, 0) == 3
315343
@test get(newp, :b, 0) == 4.5
316344
end
317345

318346
@testset "value of `nothing` is ignored" begin
319-
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
347+
sys = SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2),
320348
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
321349
function foo(du, u, p, t)
322350
du .= u .* p

0 commit comments

Comments
 (0)