Skip to content

Commit 8b73ea2

Browse files
refactor: copy initials to u0 in solve/init instead of remake
1 parent 6aa661d commit 8b73ea2

File tree

2 files changed

+52
-28
lines changed

2 files changed

+52
-28
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -595,27 +595,6 @@ function SciMLBase.late_binding_update_u0_p(
595595

596596
initdata = prob.f.initialization_data
597597
meta = initdata === nothing ? nothing : initdata.metadata
598-
# If the user passes `p` to `remake` but not `u0` and `u0` isn't empty,
599-
# and if the system supports initialization (so it has initial parameters),
600-
# and if the initialization solves for `u0`,
601-
# THEN copy the values of `Initial`s to `newu0`.
602-
if u0 === missing
603-
if newu0 !== nothing && p !== missing && supports_initialization(sys) &&
604-
initdata !== nothing && initdata.initializeprobmap !== nothing
605-
getter = if meta isa InitializationMetadata
606-
meta.get_initial_unknowns
607-
else
608-
getu(sys, Initial.(unknowns(sys)))
609-
end
610-
if ArrayInterface.ismutable(newu0)
611-
copyto!(newu0, getter(newp))
612-
else
613-
T = StaticArrays.similar_type(newu0)
614-
newu0 = T(getter(newp))
615-
end
616-
end
617-
return newu0, newp
618-
end
619598

620599
# non-symbolic u0 updates initials...
621600
if !(eltype(u0) <: Pair)
@@ -669,6 +648,40 @@ function SciMLBase.late_binding_update_u0_p(
669648
return newu0, newp
670649
end
671650

651+
function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...)
652+
supports_initialization(sys) || return prob
653+
initdata = prob.f.initialization_data
654+
initdata === nothing && return prob
655+
656+
u0 = state_values(prob)
657+
u0 === nothing && return prob
658+
659+
p = parameter_values(prob)
660+
t0 = is_time_dependent(prob) ? current_time(prob) : nothing
661+
meta = initdata.metadata
662+
663+
getter = if meta isa InitializationMetadata
664+
meta.get_initial_unknowns
665+
else
666+
getu(sys, Initial.(unknowns(sys)))
667+
end
668+
669+
if p isa MTKParameters
670+
buffer = p.initials
671+
else
672+
buffer = p
673+
end
674+
675+
u0 = DiffEqBase.promote_u0(u0, buffer, t0)
676+
677+
if ArrayInterface.ismutable(u0)
678+
T = typeof(u0)
679+
else
680+
T = StaticArrays.similar_type(u0)
681+
end
682+
return remake(prob; u0 = T(getter(p)))
683+
end
684+
672685
"""
673686
$(TYPEDSIGNATURES)
674687

test/initializationsystem.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,7 +1513,7 @@ end
15131513
@inferred solve(prob)
15141514
end
15151515

1516-
@testset "Issue#3570: `Initial`s are copied to `u0` if `u0` not provided to `remake`" begin
1516+
@testset "Issue#3570, #3552: `Initial`s are copied to `u0` during `solve`/`init`" begin
15171517
@parameters g
15181518
@variables x(t) [state_priority = 10] y(t) λ(t)
15191519
eqs = [D(D(x)) ~ λ * x
@@ -1525,11 +1525,22 @@ end
15251525
pend, [x => (2 / 2)], (0.0, 1.5), [g => 1], guesses ==> 1, y => 2 / 2])
15261526
sol = solve(prob)
15271527

1528-
setter = setsym_oop(prob, [Initial(x)])
1529-
(u0, p) = setter(prob, [0.8])
1528+
@testset "`setsym_oop`" begin
1529+
setter = setsym_oop(prob, [Initial(x)])
1530+
(u0, p) = setter(prob, [0.8])
1531+
new_prob = remake(prob; u0, p, initializealg = BrownFullBasicInit())
1532+
new_sol = solve(new_prob)
1533+
@test new_sol[x, 1] 0.8
1534+
integ = init(new_prob)
1535+
@test integ[x] 0.8
1536+
end
15301537

1531-
new_prob = remake(prob; p, initializealg = BrownFullBasicInit())
1532-
@test new_prob[x] 0.8
1533-
new_sol = solve(new_prob)
1534-
@test new_sol[x, 1] 0.8
1538+
@testset "`setsym`" begin
1539+
@test prob.ps[Initial(x)] 2 / 2
1540+
prob.ps[Initial(x)] = 0.8
1541+
sol = solve(prob; initializealg = BrownFullBasicInit())
1542+
@test sol[x, 1] 0.8
1543+
integ = init(prob; initializealg = BrownFullBasicInit())
1544+
@test integ[x] 0.8
1545+
end
15351546
end

0 commit comments

Comments
 (0)