diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 47bf3c678d..4cba787852 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -630,13 +630,17 @@ with promoted types. $(TYPEDFIELDS) """ -struct ReconstructInitializeprob{G} +struct ReconstructInitializeprob{GP, GU} """ A function which when given the original problem and initialization problem, returns the parameter object of the initialization problem with values copied from the original. """ - getter::G + pgetter::GP + """ + Given the original problem, return the `u0` of the initialization problem. + """ + ugetter::GU end """ @@ -674,6 +678,7 @@ with values from `srcsys`. function ReconstructInitializeprob( srcsys::AbstractSystem, dstsys::AbstractSystem) @assert is_initializesystem(dstsys) + ugetter = getu(srcsys, unknowns(dstsys)) if is_split(dstsys) # if we call `getu` on this (and it were able to handle empty tuples) we get the # fields of `MTKParameters` except caches. @@ -693,7 +698,7 @@ function ReconstructInitializeprob( end end getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...) - getter = let getters = getters + pgetter = let getters = getters function _getter(valp, initprob) oldcache = parameter_values(initprob).caches MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp), @@ -703,13 +708,13 @@ function ReconstructInitializeprob( end else syms = parameters(dstsys) - getter = let inner = concrete_getu(srcsys, syms) + pgetter = let inner = concrete_getu(srcsys, syms) function _getter2(valp, initprob) inner(valp) end end end - return ReconstructInitializeprob(getter) + return ReconstructInitializeprob(pgetter, ugetter) end """ @@ -719,7 +724,7 @@ Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`. """ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) # copy parameters - newp = rip.getter(srcvalp, dstvalp) + newp = rip.pgetter(srcvalp, dstvalp) # no `u0`, so no type-promotion if state_values(dstvalp) === nothing return nothing, newp @@ -735,11 +740,10 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) elseif !isempty(newp) T = promote_type(eltype(newp), T) end + u0 = rip.ugetter(srcvalp) # and the eltype of the destination u0 - if T == eltype(state_values(dstvalp)) - u0 = state_values(dstvalp) - elseif T != Union{} - u0 = T.(state_values(dstvalp)) + if T != eltype(u0) && T != Union{} + u0 = T.(u0) end # apply the promotion to tunables portion buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp) @@ -911,11 +915,13 @@ function maybe_build_initialization_problem( punknowns = [p for p in all_variable_symbols(initializeprob) if is_parameter(sys, p)] - if isempty(punknowns) + if initializeprobmap === nothing && isempty(punknowns) initializeprobpmap = nothing else - getpunknowns = getu(initializeprob, punknowns) - setpunknowns = setp(sys, punknowns) + allsyms = all_symbols(initializeprob) + initdvs = filter(x -> any(isequal(x), allsyms), unknowns(sys)) + getpunknowns = getu(initializeprob, [punknowns; initdvs]) + setpunknowns = setp(sys, [punknowns; Initial.(initdvs)]) initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) end diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index 183feca6d2..b593deb345 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -220,7 +220,7 @@ end @independent_variables t_units [unit = u"s"] D_units = Differential(t_units) @variables x(t_units) [unit = u"m"] y(t_units) [unit = u"m"] - @parameters g = 9.81 [unit = u"m * s^-2"] # gravitational acceleration + @parameters g=9.81 [unit = u"m * s^-2"] # gravitational acceleration Mt = ODESystem([D_units(D_units(y)) ~ -g, D_units(D_units(x)) ~ 0], t_units; name = :M) # gives (x, y) as function of t, ... Mx = change_independent_variable(Mt, x; add_old_diff = true) # ... but we want y as a function of x Mx = structural_simplify(Mx; allow_symbolic = true) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 5c36fcba3e..4274f31cc7 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1194,13 +1194,13 @@ end @test integ[x] ≈ 1 / cbrt(3) @test integ[y] ≈ 2 / cbrt(3) @test integ.ps[p] == 1.0 - @test integ.ps[q] ≈ 3 / cbrt(3) + @test integ.ps[q] ≈ 3 / cbrt(3) atol=1e-5 prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x]) integ2 = init(prob2) - @test integ2[x] ≈ cbrt(3 / 28) - @test integ2[y] ≈ 3cbrt(3 / 28) + @test integ2[x] ≈ cbrt(3 / 28) atol=1e-5 + @test integ2[y] ≈ 3cbrt(3 / 28) atol=1e-5 @test integ2.ps[p] == 1.0 - @test integ2.ps[q] ≈ 2cbrt(3 / 28) + @test integ2.ps[q] ≈ 2cbrt(3 / 28) atol=1e-5 end function test_dummy_initialization_equation(prob, var) @@ -1563,3 +1563,38 @@ end @test integ[x] ≈ 0.8 end end + +@testset "Initialization copies solved `u0` to `p`" begin + @parameters σ ρ β A[1:3] + @variables x(t) y(t) z(t) w(t) w2(t) + eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z, + w ~ x + y + z + 2 * β, + 0 ~ x^2 + y^2 - w2^2 + ] + + @mtkbuild sys = ODESystem(eqs, t) + + u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + + p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + + tspan = (0.0, 100.0) + getter = getsym(sys, Initial.(unknowns(sys))) + prob = ODEProblem(sys, u0, tspan, p; guesses = [w2 => 3.0]) + new_u0, new_p, _ = SciMLBase.get_initial_values( + prob, prob, prob.f, SciMLBase.OverrideInit(), Val(true); + nlsolve_alg = NewtonRaphson(), abstol = 1e-6, reltol = 1e-3) + @test getter(prob) != getter(new_p) + @test getter(new_p) == new_u0 + _prob = remake(prob, u0 = new_u0, p = new_p) + sol = solve(_prob; initializealg = CheckInit()) + @test SciMLBase.successful_retcode(sol) + @test sol.u[1] ≈ new_u0 +end