Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,17 @@ with promoted types.

$(TYPEDFIELDS)
"""
struct ReconstructInitializeprob{G}
struct ReconstructInitializeprob{GP, GU}
"""
A function which when given the original problem and initialization problem, returns
the parameter object of the initialization problem with values copied from the
original.
"""
getter::G
pgetter::GP
"""
Given the original problem, return the `u0` of the initialization problem.
"""
ugetter::GU
end

"""
Expand Down Expand Up @@ -674,6 +678,7 @@ with values from `srcsys`.
function ReconstructInitializeprob(
srcsys::AbstractSystem, dstsys::AbstractSystem)
@assert is_initializesystem(dstsys)
ugetter = getu(srcsys, unknowns(dstsys))
if is_split(dstsys)
# if we call `getu` on this (and it were able to handle empty tuples) we get the
# fields of `MTKParameters` except caches.
Expand All @@ -693,7 +698,7 @@ function ReconstructInitializeprob(
end
end
getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...)
getter = let getters = getters
pgetter = let getters = getters
function _getter(valp, initprob)
oldcache = parameter_values(initprob).caches
MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp),
Expand All @@ -703,13 +708,13 @@ function ReconstructInitializeprob(
end
else
syms = parameters(dstsys)
getter = let inner = concrete_getu(srcsys, syms)
pgetter = let inner = concrete_getu(srcsys, syms)
function _getter2(valp, initprob)
inner(valp)
end
end
end
return ReconstructInitializeprob(getter)
return ReconstructInitializeprob(pgetter, ugetter)
end

"""
Expand All @@ -719,7 +724,7 @@ Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`.
"""
function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
# copy parameters
newp = rip.getter(srcvalp, dstvalp)
newp = rip.pgetter(srcvalp, dstvalp)
# no `u0`, so no type-promotion
if state_values(dstvalp) === nothing
return nothing, newp
Expand All @@ -735,11 +740,10 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
elseif !isempty(newp)
T = promote_type(eltype(newp), T)
end
u0 = rip.ugetter(srcvalp)
# and the eltype of the destination u0
if T == eltype(state_values(dstvalp))
u0 = state_values(dstvalp)
elseif T != Union{}
u0 = T.(state_values(dstvalp))
if T != eltype(u0) && T != Union{}
u0 = T.(u0)
end
# apply the promotion to tunables portion
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
Expand Down Expand Up @@ -911,11 +915,13 @@ function maybe_build_initialization_problem(
punknowns = [p
for p in all_variable_symbols(initializeprob)
if is_parameter(sys, p)]
if isempty(punknowns)
if initializeprobmap === nothing && isempty(punknowns)
initializeprobpmap = nothing
else
getpunknowns = getu(initializeprob, punknowns)
setpunknowns = setp(sys, punknowns)
allsyms = all_symbols(initializeprob)
initdvs = filter(x -> any(isequal(x), allsyms), unknowns(sys))
getpunknowns = getu(initializeprob, [punknowns; initdvs])
setpunknowns = setp(sys, [punknowns; Initial.(initdvs)])
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
end

Expand Down
2 changes: 1 addition & 1 deletion test/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ end
@independent_variables t_units [unit = u"s"]
D_units = Differential(t_units)
@variables x(t_units) [unit = u"m"] y(t_units) [unit = u"m"]
@parameters g = 9.81 [unit = u"m * s^-2"] # gravitational acceleration
@parameters g=9.81 [unit = u"m * s^-2"] # gravitational acceleration
Mt = ODESystem([D_units(D_units(y)) ~ -g, D_units(D_units(x)) ~ 0], t_units; name = :M) # gives (x, y) as function of t, ...
Mx = change_independent_variable(Mt, x; add_old_diff = true) # ... but we want y as a function of x
Mx = structural_simplify(Mx; allow_symbolic = true)
Expand Down
41 changes: 38 additions & 3 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1194,11 +1194,11 @@ end
@test integ[x] ≈ 1 / cbrt(3)
@test integ[y] ≈ 2 / cbrt(3)
@test integ.ps[p] == 1.0
@test integ.ps[q] ≈ 3 / cbrt(3)
@test integ.ps[q] ≈ 3 / cbrt(3) atol=1e-5
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
integ2 = init(prob2)
@test integ2[x] ≈ cbrt(3 / 28)
@test integ2[y] ≈ 3cbrt(3 / 28)
@test integ2[x] ≈ cbrt(3 / 28) atol=1e-5
@test integ2[y] ≈ 3cbrt(3 / 28) atol=1e-5
@test integ2.ps[p] == 1.0
@test integ2.ps[q] ≈ 2cbrt(3 / 28)
end
Expand Down Expand Up @@ -1563,3 +1563,38 @@ end
@test integ[x] ≈ 0.8
end
end

@testset "Initialization copies solved `u0` to `p`" begin
@parameters σ ρ β A[1:3]
@variables x(t) y(t) z(t) w(t) w2(t)
eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z + 2 * β,
0 ~ x^2 + y^2 - w2^2
]

@mtkbuild sys = ODESystem(eqs, t)

u0 = [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]

p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 100.0)
getter = getsym(sys, Initial.(unknowns(sys)))
prob = ODEProblem(sys, u0, tspan, p; guesses = [w2 => 3.0])
new_u0, new_p, _ = SciMLBase.get_initial_values(
prob, prob, prob.f, SciMLBase.OverrideInit(), Val(true);
nlsolve_alg = NewtonRaphson(), abstol = 1e-6, reltol = 1e-3)
@test getter(prob) != getter(new_p)
@test getter(new_p) == new_u0
_prob = remake(prob, u0 = new_u0, p = new_p)
sol = solve(_prob; initializealg = CheckInit())
@test SciMLBase.successful_retcode(sol)
@test sol.u[1] ≈ new_u0
end
Loading