Skip to content

Commit 3d7eacf

Browse files
refactor: copy initials to u0 in solve/init instead of remake
1 parent c385c11 commit 3d7eacf

File tree

2 files changed

+52
-28
lines changed

2 files changed

+52
-28
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -591,27 +591,6 @@ function SciMLBase.late_binding_update_u0_p(
591591

592592
initdata = prob.f.initialization_data
593593
meta = initdata === nothing ? nothing : initdata.metadata
594-
# If the user passes `p` to `remake` but not `u0` and `u0` isn't empty,
595-
# and if the system supports initialization (so it has initial parameters),
596-
# and if the initialization solves for `u0`,
597-
# THEN copy the values of `Initial`s to `newu0`.
598-
if u0 === missing
599-
if newu0 !== nothing && p !== missing && supports_initialization(sys) &&
600-
initdata !== nothing && initdata.initializeprobmap !== nothing
601-
getter = if meta isa InitializationMetadata
602-
meta.get_initial_unknowns
603-
else
604-
getu(sys, Initial.(unknowns(sys)))
605-
end
606-
if ArrayInterface.ismutable(newu0)
607-
copyto!(newu0, getter(newp))
608-
else
609-
T = StaticArrays.similar_type(newu0)
610-
newu0 = T(getter(newp))
611-
end
612-
end
613-
return newu0, newp
614-
end
615594

616595
# non-symbolic u0 updates initials...
617596
if !(eltype(u0) <: Pair)
@@ -665,6 +644,40 @@ function SciMLBase.late_binding_update_u0_p(
665644
return newu0, newp
666645
end
667646

647+
function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...)
648+
supports_initialization(sys) || return prob
649+
initdata = prob.f.initialization_data
650+
initdata === nothing && return prob
651+
652+
u0 = state_values(prob)
653+
u0 === nothing && return prob
654+
655+
p = parameter_values(prob)
656+
t0 = is_time_dependent(prob) ? current_time(prob) : nothing
657+
meta = initdata.metadata
658+
659+
getter = if meta isa InitializationMetadata
660+
meta.get_initial_unknowns
661+
else
662+
getu(sys, Initial.(unknowns(sys)))
663+
end
664+
665+
if p isa MTKParameters
666+
buffer = p.initials
667+
else
668+
buffer = p
669+
end
670+
671+
u0 = DiffEqBase.promote_u0(u0, buffer, t0)
672+
673+
if ArrayInterface.ismutable(u0)
674+
T = typeof(u0)
675+
else
676+
T = StaticArrays.similar_type(u0)
677+
end
678+
return remake(prob; u0 = T(getter(p)))
679+
end
680+
668681
"""
669682
$(TYPEDSIGNATURES)
670683

test/initializationsystem.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,7 +1508,7 @@ end
15081508
@inferred solve(prob)
15091509
end
15101510

1511-
@testset "Issue#3570: `Initial`s are copied to `u0` if `u0` not provided to `remake`" begin
1511+
@testset "Issue#3570, #3552: `Initial`s are copied to `u0` during `solve`/`init`" begin
15121512
@parameters g
15131513
@variables x(t) [state_priority = 10] y(t) λ(t)
15141514
eqs = [D(D(x)) ~ λ * x
@@ -1520,11 +1520,22 @@ end
15201520
pend, [x => (2 / 2)], (0.0, 1.5), [g => 1], guesses ==> 1, y => 2 / 2])
15211521
sol = solve(prob)
15221522

1523-
setter = setsym_oop(prob, [Initial(x)])
1524-
(u0, p) = setter(prob, [0.8])
1523+
@testset "`setsym_oop`" begin
1524+
setter = setsym_oop(prob, [Initial(x)])
1525+
(u0, p) = setter(prob, [0.8])
1526+
new_prob = remake(prob; u0, p, initializealg = BrownFullBasicInit())
1527+
new_sol = solve(new_prob)
1528+
@test new_sol[x, 1] 0.8
1529+
integ = init(new_prob)
1530+
@test integ[x] 0.8
1531+
end
15251532

1526-
new_prob = remake(prob; p, initializealg = BrownFullBasicInit())
1527-
@test new_prob[x] 0.8
1528-
new_sol = solve(new_prob)
1529-
@test new_sol[x, 1] 0.8
1533+
@testset "`setsym`" begin
1534+
@test prob.ps[Initial(x)] 2 / 2
1535+
prob.ps[Initial(x)] = 0.8
1536+
sol = solve(prob; initializealg = BrownFullBasicInit())
1537+
@test sol[x, 1] 0.8
1538+
integ = init(prob; initializealg = BrownFullBasicInit())
1539+
@test integ[x] 0.8
1540+
end
15301541
end

0 commit comments

Comments
 (0)