Skip to content

Commit a121f72

Browse files
Merge pull request #781 from AayushSabharwal/as/remake
fix: fix `remake` handling identical variables with different names
2 parents f5b6587 + 1d12433 commit a121f72

File tree

2 files changed

+76
-21
lines changed

2 files changed

+76
-21
lines changed

src/remake.jl

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -543,40 +543,76 @@ end
543543

544544
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
545545
vsyms = variable_symbols(prob)
546-
if length(u0) == length(vsyms)
547-
return u0
546+
sym_to_idx = anydict()
547+
idx_to_sym = anydict()
548+
idx_to_val = anydict()
549+
for (k, v) in u0
550+
idx = variable_index(prob, k)
551+
idx === nothing && continue
552+
sym_to_idx[k] = idx
553+
idx_to_sym[idx] = k
554+
idx_to_val[idx] = v
548555
end
549-
newvals = anydict()
550556
for sym in vsyms
551-
varmap_has_var(u0, sym) && continue
552-
def = if defs === nothing || (defval = varmap_get(defs, sym)) === nothing ||
553-
(symbolic_type(defval) === NotSymbolic() && !use_defaults)
554-
nothing
555-
else
557+
haskey(sym_to_idx, sym) && continue
558+
idx = variable_index(prob, sym)
559+
haskey(idx_to_val, idx) && continue
560+
sym_to_idx[sym] = idx
561+
idx_to_sym[idx] = sym
562+
idx_to_val[idx] = if defs !== nothing &&
563+
(defval = varmap_get(defs, sym)) !== nothing &&
564+
(symbolic_type(defval) != NotSymbolic() || use_defaults)
556565
defval
566+
else
567+
getu(prob, sym)(prob)
557568
end
558-
newvals[sym] = @something def getu(prob, sym)(prob)
559569
end
560-
return merge(u0, newvals)
570+
newvals = anydict()
571+
for (idx, val) in idx_to_val
572+
newvals[idx_to_sym[idx]] = val
573+
end
574+
for (k, v) in u0
575+
haskey(sym_to_idx, k) && continue
576+
newvals[k] = v
577+
end
578+
return newvals
561579
end
562580

563581
function fill_p(prob, p; defs = nothing, use_defaults = false)
564-
psyms = parameter_symbols(prob)::Vector
565-
if length(p) == length(psyms)
566-
return p
582+
psyms = parameter_symbols(prob)
583+
sym_to_idx = anydict()
584+
idx_to_sym = anydict()
585+
idx_to_val = anydict()
586+
for (k, v) in p
587+
idx = parameter_index(prob, k)
588+
idx === nothing && continue
589+
sym_to_idx[k] = idx
590+
idx_to_sym[idx] = k
591+
idx_to_val[idx] = v
567592
end
568-
newvals = anydict()
569593
for sym in psyms
570-
varmap_has_var(p, sym) && continue
571-
def = if defs === nothing || (defval = varmap_get(defs, sym)) === nothing ||
572-
(symbolic_type(defval) === NotSymbolic() && !use_defaults)
573-
nothing
574-
else
594+
haskey(sym_to_idx, sym) && continue
595+
idx = parameter_index(prob, sym)
596+
haskey(idx_to_val, idx) && continue
597+
sym_to_idx[sym] = idx
598+
idx_to_sym[idx] = sym
599+
idx_to_val[idx] = if defs !== nothing &&
600+
(defval = varmap_get(defs, sym)) !== nothing &&
601+
(symbolic_type(defval) != NotSymbolic() || use_defaults)
575602
defval
603+
else
604+
getp(prob, sym)(prob)
576605
end
577-
newvals[sym] = @something def getp(prob, sym)(prob)
578606
end
579-
return merge(p, newvals)
607+
newvals = anydict()
608+
for (idx, val) in idx_to_val
609+
newvals[idx_to_sym[idx]] = val
610+
end
611+
for (k, v) in p
612+
haskey(sym_to_idx, k) && continue
613+
newvals[k] = v
614+
end
615+
return newvals
580616
end
581617

582618
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})

test/remake_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,22 @@ a = Remake_Test1(p = 1)
295295
@test @inferred remake(a, p = 2) == Remake_Test1(p = 2)
296296
@test @inferred remake(a, args = 1) == Remake_Test1(1, p = 1)
297297
@test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1)
298+
299+
@testset "fill_u0 and fill_p ignore identical variables with different names" begin
300+
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
301+
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
302+
function foo(du, u, p, t)
303+
du .= u .* p
304+
end
305+
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
306+
u0 = Dict(:x2 => 2)
307+
newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys))
308+
@test length(newu0) == 2
309+
@test get(newu0, :x2, 0) == 2
310+
@test get(newu0, :y, 0) == 2.5
311+
p = Dict(:a2 => 3)
312+
newp = SciMLBase.fill_p(prob, p; defs = default_values(sys))
313+
@test length(newp) == 2
314+
@test get(newp, :a2, 0) == 3
315+
@test get(newp, :b, 0) == 4.5
316+
end

0 commit comments

Comments
 (0)