Skip to content

Commit c449c50

Browse files
Merge pull request #3599 from AayushSabharwal/as/initprobpmap-copy-initials
fix: allow solving initialization separately
2 parents 4a8f1ed + 0c56d86 commit c449c50

File tree

3 files changed

+59
-18
lines changed

3 files changed

+59
-18
lines changed

src/systems/problem_utils.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -630,13 +630,17 @@ with promoted types.
630630
631631
$(TYPEDFIELDS)
632632
"""
633-
struct ReconstructInitializeprob{G}
633+
struct ReconstructInitializeprob{GP, GU}
634634
"""
635635
A function which when given the original problem and initialization problem, returns
636636
the parameter object of the initialization problem with values copied from the
637637
original.
638638
"""
639-
getter::G
639+
pgetter::GP
640+
"""
641+
Given the original problem, return the `u0` of the initialization problem.
642+
"""
643+
ugetter::GU
640644
end
641645

642646
"""
@@ -674,6 +678,7 @@ with values from `srcsys`.
674678
function ReconstructInitializeprob(
675679
srcsys::AbstractSystem, dstsys::AbstractSystem)
676680
@assert is_initializesystem(dstsys)
681+
ugetter = getu(srcsys, unknowns(dstsys))
677682
if is_split(dstsys)
678683
# if we call `getu` on this (and it were able to handle empty tuples) we get the
679684
# fields of `MTKParameters` except caches.
@@ -693,7 +698,7 @@ function ReconstructInitializeprob(
693698
end
694699
end
695700
getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...)
696-
getter = let getters = getters
701+
pgetter = let getters = getters
697702
function _getter(valp, initprob)
698703
oldcache = parameter_values(initprob).caches
699704
MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp),
@@ -703,13 +708,13 @@ function ReconstructInitializeprob(
703708
end
704709
else
705710
syms = parameters(dstsys)
706-
getter = let inner = concrete_getu(srcsys, syms)
711+
pgetter = let inner = concrete_getu(srcsys, syms)
707712
function _getter2(valp, initprob)
708713
inner(valp)
709714
end
710715
end
711716
end
712-
return ReconstructInitializeprob(getter)
717+
return ReconstructInitializeprob(pgetter, ugetter)
713718
end
714719

715720
"""
@@ -719,7 +724,7 @@ Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`.
719724
"""
720725
function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
721726
# copy parameters
722-
newp = rip.getter(srcvalp, dstvalp)
727+
newp = rip.pgetter(srcvalp, dstvalp)
723728
# no `u0`, so no type-promotion
724729
if state_values(dstvalp) === nothing
725730
return nothing, newp
@@ -735,11 +740,10 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
735740
elseif !isempty(newp)
736741
T = promote_type(eltype(newp), T)
737742
end
743+
u0 = rip.ugetter(srcvalp)
738744
# and the eltype of the destination u0
739-
if T == eltype(state_values(dstvalp))
740-
u0 = state_values(dstvalp)
741-
elseif T != Union{}
742-
u0 = T.(state_values(dstvalp))
745+
if T != eltype(u0) && T != Union{}
746+
u0 = T.(u0)
743747
end
744748
# apply the promotion to tunables portion
745749
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
@@ -911,11 +915,13 @@ function maybe_build_initialization_problem(
911915
punknowns = [p
912916
for p in all_variable_symbols(initializeprob)
913917
if is_parameter(sys, p)]
914-
if isempty(punknowns)
918+
if initializeprobmap === nothing && isempty(punknowns)
915919
initializeprobpmap = nothing
916920
else
917-
getpunknowns = getu(initializeprob, punknowns)
918-
setpunknowns = setp(sys, punknowns)
921+
allsyms = all_symbols(initializeprob)
922+
initdvs = filter(x -> any(isequal(x), allsyms), unknowns(sys))
923+
getpunknowns = getu(initializeprob, [punknowns; initdvs])
924+
setpunknowns = setp(sys, [punknowns; Initial.(initdvs)])
919925
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
920926
end
921927

test/basic_transformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ end
220220
@independent_variables t_units [unit = u"s"]
221221
D_units = Differential(t_units)
222222
@variables x(t_units) [unit = u"m"] y(t_units) [unit = u"m"]
223-
@parameters g = 9.81 [unit = u"m * s^-2"] # gravitational acceleration
223+
@parameters g=9.81 [unit = u"m * s^-2"] # gravitational acceleration
224224
Mt = ODESystem([D_units(D_units(y)) ~ -g, D_units(D_units(x)) ~ 0], t_units; name = :M) # gives (x, y) as function of t, ...
225225
Mx = change_independent_variable(Mt, x; add_old_diff = true) # ... but we want y as a function of x
226226
Mx = structural_simplify(Mx; allow_symbolic = true)

test/initializationsystem.jl

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,13 +1194,13 @@ end
11941194
@test integ[x] 1 / cbrt(3)
11951195
@test integ[y] 2 / cbrt(3)
11961196
@test integ.ps[p] == 1.0
1197-
@test integ.ps[q] 3 / cbrt(3)
1197+
@test integ.ps[q] 3 / cbrt(3) atol=1e-5
11981198
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
11991199
integ2 = init(prob2)
1200-
@test integ2[x] cbrt(3 / 28)
1201-
@test integ2[y] 3cbrt(3 / 28)
1200+
@test integ2[x] cbrt(3 / 28) atol=1e-5
1201+
@test integ2[y] 3cbrt(3 / 28) atol=1e-5
12021202
@test integ2.ps[p] == 1.0
1203-
@test integ2.ps[q] 2cbrt(3 / 28)
1203+
@test integ2.ps[q] 2cbrt(3 / 28) atol=1e-5
12041204
end
12051205

12061206
function test_dummy_initialization_equation(prob, var)
@@ -1563,3 +1563,38 @@ end
15631563
@test integ[x] 0.8
15641564
end
15651565
end
1566+
1567+
@testset "Initialization copies solved `u0` to `p`" begin
1568+
@parameters σ ρ β A[1:3]
1569+
@variables x(t) y(t) z(t) w(t) w2(t)
1570+
eqs = [D(D(x)) ~ σ * (y - x),
1571+
D(y) ~ x *- z) - y,
1572+
D(z) ~ x * y - β * z,
1573+
w ~ x + y + z + 2 * β,
1574+
0 ~ x^2 + y^2 - w2^2
1575+
]
1576+
1577+
@mtkbuild sys = ODESystem(eqs, t)
1578+
1579+
u0 = [D(x) => 2.0,
1580+
x => 1.0,
1581+
y => 0.0,
1582+
z => 0.0]
1583+
1584+
p ==> 28.0,
1585+
ρ => 10.0,
1586+
β => 8 / 3]
1587+
1588+
tspan = (0.0, 100.0)
1589+
getter = getsym(sys, Initial.(unknowns(sys)))
1590+
prob = ODEProblem(sys, u0, tspan, p; guesses = [w2 => 3.0])
1591+
new_u0, new_p, _ = SciMLBase.get_initial_values(
1592+
prob, prob, prob.f, SciMLBase.OverrideInit(), Val(true);
1593+
nlsolve_alg = NewtonRaphson(), abstol = 1e-6, reltol = 1e-3)
1594+
@test getter(prob) != getter(new_p)
1595+
@test getter(new_p) == new_u0
1596+
_prob = remake(prob, u0 = new_u0, p = new_p)
1597+
sol = solve(_prob; initializealg = CheckInit())
1598+
@test SciMLBase.successful_retcode(sol)
1599+
@test sol.u[1] new_u0
1600+
end

0 commit comments

Comments
 (0)