Skip to content

Commit bd16440

Browse files
fix: properly handle values given to parameter dependencies in late_binding_update_u0_p
1 parent 1067cc0 commit bd16440

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,25 @@ function SciMLBase.late_binding_update_u0_p(
782782
newu0, newp = promote_u0_p(newu0, newp, t0)
783783

784784
# non-symbolic u0 updates initials...
785-
if !(eltype(u0) <: Pair)
785+
if eltype(u0) <: Pair
786+
syms = []
787+
vals = []
788+
allsyms = all_symbols(sys)
789+
for (k, v) in u0
790+
v === nothing && continue
791+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
792+
if k isa Symbol
793+
k2 = symbol_to_symbolic(sys, k; allsyms)
794+
# if it is returned as-is, there is no match so skip it
795+
k2 === k && continue
796+
k = k2
797+
end
798+
is_parameter(sys, Initial(k)) || continue
799+
push!(syms, Initial(k))
800+
push!(vals, v)
801+
end
802+
newp = setp_oop(sys, syms)(newp, vals)
803+
else
786804
# if `p` is not provided or is symbolic
787805
p === missing || eltype(p) <: Pair || return newu0, newp
788806
(newu0 === nothing || isempty(newu0)) && return newu0, newp
@@ -795,27 +813,27 @@ function SciMLBase.late_binding_update_u0_p(
795813
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
796814
end
797815
newp = meta.set_initial_unknowns!(newp, newu0)
798-
return newu0, newp
799-
end
800-
801-
syms = []
802-
vals = []
803-
allsyms = all_symbols(sys)
804-
for (k, v) in u0
805-
v === nothing && continue
806-
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
807-
if k isa Symbol
808-
k2 = symbol_to_symbolic(sys, k; allsyms)
809-
# if it is returned as-is, there is no match so skip it
810-
k2 === k && continue
811-
k = k2
816+
end
817+
818+
if eltype(p) <: Pair
819+
syms = []
820+
vals = []
821+
for (k, v) in p
822+
v === nothing && continue
823+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
824+
if k isa Symbol
825+
k2 = symbol_to_symbolic(sys, k; allsyms)
826+
# if it is returned as-is, there is no match so skip it
827+
k2 === k && continue
828+
k = k2
829+
end
830+
is_parameter(sys, Initial(k)) || continue
831+
push!(syms, Initial(k))
832+
push!(vals, v)
812833
end
813-
is_parameter(sys, Initial(k)) || continue
814-
push!(syms, Initial(k))
815-
push!(vals, v)
834+
newp = setp_oop(sys, syms)(newp, vals)
816835
end
817836

818-
newp = setp_oop(sys, syms)(newp, vals)
819837
return newu0, newp
820838
end
821839

test/initializationsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,9 +1253,9 @@ end
12531253
@test init(prob3)[x] 1.0
12541254
prob4 = remake(prob; p = [p => 1.0])
12551255
test_dummy_initialization_equation(prob4, x)
1256-
prob5 = remake(prob; p = [p => missing, q => 2.0])
1256+
prob5 = remake(prob; p = [p => missing, q => 4.0])
12571257
@test prob5.f.initialization_data !== nothing
1258-
@test init(prob5).ps[p] 1.0
1258+
@test init(prob5).ps[p] 2.0
12591259
end
12601260

12611261
@testset "Variables provided as symbols" begin

0 commit comments

Comments
 (0)