diff --git a/Project.toml b/Project.toml index 30580d961b..5b5b2eeec8 100644 --- a/Project.toml +++ b/Project.toml @@ -139,7 +139,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.75" +SciMLBase = "2.84" SciMLStructures = "1.7" Serialization = "1" Setfield = "0.7, 0.8, 1" @@ -150,7 +150,7 @@ StaticArrays = "0.10, 0.11, 0.12, 1.0" StochasticDelayDiffEq = "1.8.1" StochasticDiffEq = "6.72.1" SymbolicIndexingInterface = "0.3.39" -SymbolicUtils = "3.25.1" +SymbolicUtils = "3.26.1" Symbolics = "6.37" URIs = "1" UnPack = "0.1, 1.0" diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 23e1e6d7d1..3a76f478f1 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -126,8 +126,8 @@ prob = ODEProblem(ball, Pair[], tspan) sol = solve(prob, Tsit5()) @assert 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close -@assert minimum(sol[y]) > -1.5 # check wall conditions -@assert maximum(sol[y]) < 1.5 # check wall conditions +@assert minimum(sol[y]) >= -1.5 # check wall conditions +@assert maximum(sol[y]) <= 1.5 # check wall conditions tv = sort([LinRange(0, 10, 200); sol.t]) plot(sol(tv)[y], sol(tv)[x], line_z = tv) diff --git a/docs/src/tutorials/disturbance_modeling.md b/docs/src/tutorials/disturbance_modeling.md index b77a73b0c1..db8d926498 100644 --- a/docs/src/tutorials/disturbance_modeling.md +++ b/docs/src/tutorials/disturbance_modeling.md @@ -224,7 +224,7 @@ To see full examples that perform state estimation with ModelingToolkit models, Pages = ["disturbance_modeling.md"] ``` -```@autodocs +```@autodocs; canonical = false Modules = [ModelingToolkit] Pages = ["systems/analysis_points.jl"] Order = [:function, :type] diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 29338d0722..782d1b5229 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -915,6 +915,7 @@ for prop in [:eqs :substitutions :metadata :gui_metadata + :is_initializesystem :discrete_subsystems :parameter_dependencies :assertions diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 62ddd12a08..67b47a3964 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1457,12 +1457,12 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, elseif isempty(u0map) && get_initializesystem(sys) === nothing isys = generate_initializesystem( sys; initialization_eqs, check_units, pmap = parammap, - guesses, extra_metadata = (; use_scc), algebraic_only) + guesses, algebraic_only) simplify_system = true else isys = generate_initializesystem( sys; u0map, initialization_eqs, check_units, - pmap = parammap, guesses, extra_metadata = (; use_scc), algebraic_only) + pmap = parammap, guesses, algebraic_only) simplify_system = true end @@ -1477,12 +1477,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, isys = structural_simplify(isys; fully_determined) end - meta = get_metadata(isys) - if meta isa InitializationSystemMetadata - @set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob( - sys, isys) - end - ts = get_tearing_state(isys) unassigned_vars = StructuralTransformations.singular_check(ts) if warn_initialize_determined && !isempty(unassigned_vars) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 47a784c00b..d0b687c212 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -502,7 +502,7 @@ function reorder_parameters( end end -function reorder_parameters(ic::IndexCache, ps; drop_missing = false) +function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true) isempty(ps) && return () param_buf = if ic.tunable_buffer_size.length == 0 () @@ -555,20 +555,37 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false) end end - result = broadcast.( - unwrap, ( - param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...)) + param_buf = broadcast.(unwrap, param_buf) + initials_buf = broadcast.(unwrap, initials_buf) + disc_buf = broadcast.(unwrap, disc_buf) + const_buf = broadcast.(unwrap, const_buf) + nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf) + if drop_missing - result = map(result) do buf - filter(buf) do sym - return !isequal(sym, unwrap(variable(:DEF))) - end + filterer = !isequal(unwrap(variable(:DEF))) + param_buf = filter.(filterer, param_buf) + initials_buf = filter.(filterer, initials_buf) + disc_buf = filter.(filterer, disc_buf) + const_buf = filter.(filterer, const_buf) + nonnumeric_buf = filter.(filterer, nonnumeric_buf) + end + + if flatten + result = ( + param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...) + if all(isempty, result) + return () end + return result + else + if isempty(param_buf) + param_buf = ((),) + end + if isempty(initials_buf) + initials_buf = ((),) + end + return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf) end - if all(isempty, result) - return () - end - return result end # Given a parameter index, find the index of the buffer it is in when diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index ec25b9b660..f83ea6006f 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -11,7 +11,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem; default_dd_guess = Bool(0), algebraic_only = false, check_units = true, check_defguess = false, - name = nameof(sys), extra_metadata = (;), kwargs...) + name = nameof(sys), kwargs...) eqs = equations(sys) if !(eqs isa Vector{Equation}) eqs = Equation[x for x in eqs if x isa Equation] @@ -143,9 +143,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem; for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) end - meta = InitializationSystemMetadata( - anydict(u0map), anydict(pmap), additional_guesses, - additional_initialization_eqs, extra_metadata, nothing) + return NonlinearSystem(eqs_ics, vars, pars; @@ -153,7 +151,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem; checks = check_units, parameter_dependencies = new_parameter_deps, name, - metadata = meta, + is_initializesystem = true, kwargs...) end @@ -169,7 +167,7 @@ function generate_initializesystem(sys::AbstractTimeIndependentSystem; guesses = Dict(), algebraic_only = false, check_units = true, check_defguess = false, - name = nameof(sys), extra_metadata = (;), kwargs...) + name = nameof(sys), kwargs...) eqs = equations(sys) trueobs, eqs = unhack_observed(observed(sys), eqs) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) @@ -244,9 +242,7 @@ function generate_initializesystem(sys::AbstractTimeIndependentSystem; for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) end - meta = InitializationSystemMetadata( - anydict(u0map), anydict(pmap), additional_guesses, - additional_initialization_eqs, extra_metadata, nothing) + return NonlinearSystem(eqs_ics, vars, pars; @@ -254,7 +250,7 @@ function generate_initializesystem(sys::AbstractTimeIndependentSystem; checks = check_units, parameter_dependencies = new_parameter_deps, name, - metadata = meta, + is_initializesystem = true, kwargs...) end @@ -436,64 +432,6 @@ function _has_delays(sys::AbstractSystem, ex, banned) return any(x -> _has_delays(sys, x, banned), args) end -struct ReconstructInitializeprob - getter::Any - setter::Any -end - -function ReconstructInitializeprob( - srcsys::AbstractSystem, dstsys::AbstractSystem) - syms = reduce( - vcat, reorder_parameters(dstsys, parameters(dstsys)); - init = []) - getter = getu(srcsys, syms) - setter = setp_oop(dstsys, syms) - return ReconstructInitializeprob(getter, setter) -end - -function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) - newp = rip.setter(dstvalp, rip.getter(srcvalp)) - if state_values(dstvalp) === nothing - return nothing, newp - end - srcu0 = state_values(srcvalp) - T = srcu0 === nothing || isempty(srcu0) ? Union{} : eltype(srcu0) - if parameter_values(dstvalp) isa MTKParameters - if !isempty(newp.tunable) - T = promote_type(eltype(newp.tunable), T) - end - elseif !isempty(newp) - T = promote_type(eltype(newp), T) - end - if T == eltype(state_values(dstvalp)) - u0 = state_values(dstvalp) - elseif T != Union{} - u0 = T.(state_values(dstvalp)) - end - buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp) - if eltype(buf) != T - newbuf = similar(buf, T) - copyto!(newbuf, buf) - newp = repack(newbuf) - end - buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - if eltype(buf) != T - newbuf = similar(buf, T) - copyto!(newbuf, buf) - newp = repack(newbuf) - end - return u0, newp -end - -struct InitializationSystemMetadata - u0map::Dict{Any, Any} - pmap::Dict{Any, Any} - additional_guesses::Dict{Any, Any} - additional_initialization_eqs::Vector{Equation} - extra_metadata::NamedTuple - oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob} -end - function get_possibly_array_fallback_singletons(varmap, p) if haskey(varmap, p) return varmap[p] @@ -543,22 +481,19 @@ function SciMLBase.remake_initialization_data( if u0 === missing && p === missing return odefn.initialization_data end + + oldinitdata = odefn.initialization_data + if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair) - oldinitdata = odefn.initialization_data oldinitdata === nothing && return nothing oldinitprob = oldinitdata.initializeprob oldinitprob === nothing && return nothing - if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem) - return oldinitdata - end - oldinitsys = oldinitprob.f.sys - meta = get_metadata(oldinitsys) - if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing - reconstruct_fn = meta.oop_reconstruct_u0_p - else - reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys) - end + + meta = oldinitdata.metadata + meta isa InitializationMetadata || return oldinitdata + + reconstruct_fn = meta.oop_reconstruct_u0_p # the history function doesn't matter because `reconstruct_fn` is only going to # update the values of parameters, which aren't time dependent. The reason it # is called is because `Initial` parameters are calculated from the corresponding @@ -569,16 +504,15 @@ function SciMLBase.remake_initialization_data( if oldinitprob.f.resid_prototype === nothing newf = oldinitprob.f else - newf = NonlinearFunction{ - SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}( - oldinitprob.f; + newf = remake(oldinitprob.f; resid_prototype = calculate_resid_prototype( length(oldinitprob.f.resid_prototype), new_initu0, new_initp)) end initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp) return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!, - oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap) + oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata) end + dvs = unknowns(sys) ps = parameters(sys) u0map = to_varmap(u0, dvs) @@ -592,16 +526,13 @@ function SciMLBase.remake_initialization_data( use_scc = true initialization_eqs = Equation[] - if SciMLBase.has_initializeprob(odefn) - oldsys = odefn.initialization_data.initializeprob.f.sys - meta = get_metadata(oldsys) - if meta isa InitializationSystemMetadata - u0map = merge(meta.u0map, u0map) - pmap = merge(meta.pmap, pmap) - merge!(guesses, meta.additional_guesses) - use_scc = get(meta.extra_metadata, :use_scc, true) - initialization_eqs = meta.additional_initialization_eqs - end + if oldinitdata !== nothing && oldinitdata.metadata isa InitializationMetadata + meta = oldinitdata.metadata + u0map = merge(meta.u0map, u0map) + pmap = merge(meta.pmap, pmap) + merge!(guesses, meta.guesses) + use_scc = meta.use_scc + initialization_eqs = meta.additional_initialization_eqs else # there is no initializeprob, so the original problem construction # had no solvable parameters and had the differential variables @@ -662,8 +593,11 @@ function SciMLBase.late_binding_update_u0_p( if !(eltype(u0) <: Pair) # if `p` is not provided or is symbolic p === missing || eltype(p) <: Pair || return newu0, newp - newu0 === nothing && return newu0, newp - all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp + (newu0 === nothing || isempty(newu0)) && return newu0, newp + initdata = prob.f.initialization_data + initdata === nothing && return newu0, newp + meta = initdata.metadata + meta isa InitializationMetadata || return newu0, newp newp = p === missing ? copy(newp) : newp initials, repack, alias = SciMLStructures.canonicalize( SciMLStructures.Initials(), newp) @@ -671,10 +605,10 @@ function SciMLBase.late_binding_update_u0_p( initials = DiffEqBase.promote_u0(initials, newu0, t0) newp = repack(initials) end - if length(newu0) != length(unknowns(sys)) - throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))")) + if length(newu0) != length(prob.u0) + throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))")) end - setp(sys, Initial.(unknowns(sys)))(newp, newu0) + meta.set_initial_unknowns!(newp, newu0) return newu0, newp end @@ -714,7 +648,7 @@ end Check if the given system is an initialization system. """ function is_initializesystem(sys::AbstractSystem) - sys isa NonlinearSystem && get_metadata(sys) isa InitializationSystemMetadata + has_is_initializesystem(sys) && get_is_initializesystem(sys) end """ diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 856822492b..61fd4ecbf1 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -87,6 +87,10 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem """ gui_metadata::Union{Nothing, GUIMetadata} """ + Whether this is an initialization system. + """ + is_initializesystem::Bool + """ Cache for intermediate tearing state. """ tearing_state::Any @@ -116,6 +120,7 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem tag, eqs, unknowns, ps, var_to_name, observed, jac, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, parameter_dependencies = Equation[], metadata = nothing, gui_metadata = nothing, + is_initializesystem = false, tearing_state = nothing, substitutions = nothing, namespacing = true, complete = false, index_cache = nothing, parent = nothing, isscheduled = false; checks::Union{Bool, Int} = true) @@ -126,7 +131,8 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem end new(tag, eqs, unknowns, ps, var_to_name, observed, jac, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, - connector_type, parameter_dependencies, metadata, gui_metadata, tearing_state, + connector_type, parameter_dependencies, metadata, gui_metadata, + is_initializesystem, tearing_state, substitutions, namespacing, complete, index_cache, parent, isscheduled) end end @@ -148,7 +154,8 @@ function NonlinearSystem(eqs, unknowns, ps; checks = true, parameter_dependencies = Equation[], metadata = nothing, - gui_metadata = nothing) + gui_metadata = nothing, + is_initializesystem = false) continuous_events === nothing || isempty(continuous_events) || throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events")) discrete_events === nothing || isempty(discrete_events) || @@ -196,7 +203,7 @@ function NonlinearSystem(eqs, unknowns, ps; NonlinearSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, dvs′, ps′, var_to_name, observed, jac, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, parameter_dependencies, - metadata, gui_metadata, checks = checks) + metadata, gui_metadata, is_initializesystem, checks = checks) end function NonlinearSystem(eqs; kwargs...) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 0750585905..81144213cc 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -620,6 +620,188 @@ function build_operating_point!(sys::AbstractSystem, return op, missing_unknowns, missing_pars end +""" + $(TYPEDEF) + +A callable struct used to reconstruct the `u0` and `p` of the initialization problem +with promoted types. + +# Fields + +$(TYPEDFIELDS) +""" +struct ReconstructInitializeprob{G} + """ + A function which when given the original problem and initialization problem, returns + the parameter object of the initialization problem with values copied from the + original. + """ + getter::G +end + +""" + $(TYPEDSIGNATURES) + +Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter +function by splitting `syms` into contiguous buffers where the getter of each buffer +is type-stable and constructing a function that calls and concatenates the results. +""" +function concrete_getu(indp, syms::AbstractVector) + # a list of contiguous buffer + split_syms = [Any[syms[1]]] + # the type of the getter of the last buffer + current = typeof(getu(indp, syms[1])) + for sym in syms[2:end] + getter = getu(indp, sym) + if typeof(getter) != current + # if types don't match, build a new buffer + push!(split_syms, []) + current = typeof(getter) + end + push!(split_syms[end], sym) + end + split_syms = Tuple(split_syms) + # the getter is now type-stable, and we can vcat it to get the full buffer + return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms) +end + +""" + $(TYPEDSIGNATURES) + +Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `dstsys` +with values from `srcsys`. +""" +function ReconstructInitializeprob( + srcsys::AbstractSystem, dstsys::AbstractSystem) + @assert is_initializesystem(dstsys) + if is_split(dstsys) + # if we call `getu` on this (and it were able to handle empty tuples) we get the + # fields of `MTKParameters` except caches. + syms = reorder_parameters(dstsys, parameters(dstsys); flatten = false) + # `dstsys` is an initialization system, do basically everything is a tunable + # and tunables are a mix of different types in `srcsys`. No initials. Constants + # are going to be constants in `srcsys`, as are `nonnumeric`. + + # `syms[1]` is always the tunables because `srcsys` will have initials. + tunable_syms = syms[1] + tunable_getter = concrete_getu(srcsys, tunable_syms) + rest_getters = map(Base.tail(Base.tail(syms))) do buf + if buf == () + return Returns(()) + else + return getu(srcsys, buf) + end + end + getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...) + getter = let getters = getters + function _getter(valp, initprob) + oldcache = parameter_values(initprob).caches + MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp), + getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () : + copy.(oldcache)) + end + end + else + syms = parameters(dstsys) + getter = let inner = concrete_getu(srcsys, syms) + function _getter2(valp, initprob) + inner(valp) + end + end + end + return ReconstructInitializeprob(getter) +end + +""" + $(TYPEDSIGNATURES) + +Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`. +""" +function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) + # copy parameters + newp = rip.getter(srcvalp, dstvalp) + # no `u0`, so no type-promotion + if state_values(dstvalp) === nothing + return nothing, newp + end + # the `eltype` of the `u0` of the source + srcu0 = state_values(srcvalp) + T = srcu0 === nothing ? Union{} : eltype(srcu0) + # promote with the tunable eltype + if parameter_values(dstvalp) isa MTKParameters + if !isempty(newp.tunable) + T = promote_type(eltype(newp.tunable), T) + end + elseif !isempty(newp) + T = promote_type(eltype(newp), T) + end + # and the eltype of the destination u0 + if T == eltype(state_values(dstvalp)) + u0 = state_values(dstvalp) + elseif T != Union{} + u0 = T.(state_values(dstvalp)) + end + # apply the promotion to tunables portion + buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp) + if eltype(buf) != T + # only do a copy if the eltype doesn't match + newbuf = similar(buf, T) + copyto!(newbuf, buf) + newp = repack(newbuf) + end + # and initials portion + buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) + if eltype(buf) != T + newbuf = similar(buf, T) + copyto!(newbuf, buf) + newp = repack(newbuf) + end + return u0, newp +end + +""" + $(TYPEDEF) + +Metadata attached to `OverrideInitData` used in `remake` hooks for handling initialization +properly. + +# Fields + +$(TYPEDFIELDS) +""" +struct InitializationMetadata{R <: ReconstructInitializeprob, SIU} + """ + The `u0map` used to construct the initialization. + """ + u0map::Dict{Any, Any} + """ + The `pmap` used to construct the initialization. + """ + pmap::Dict{Any, Any} + """ + The `guesses` used to construct the initialization. + """ + guesses::Dict{Any, Any} + """ + The `initialization_eqs` in addition to those of the system that were used to construct + the initialization. + """ + additional_initialization_eqs::Vector{Equation} + """ + Whether to use `SCCNonlinearProblem` if possible. + """ + use_scc::Bool + """ + `ReconstructInitializeprob` for this initialization problem. + """ + oop_reconstruct_u0_p::R + """ + A function which takes the `u0` of the problem and sets + `Initial.(unknowns(sys))`. + """ + set_initial_unknowns!::SIU +end + """ $(TYPEDSIGNATURES) @@ -632,8 +814,8 @@ All other keyword arguments are forwarded to `InitializationProblem`. """ function maybe_build_initialization_problem( sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, - guesses, missing_unknowns; implicit_dae = false, - u0_constructor = identity, floatT = Float64, kwargs...) + guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, + floatT = Float64, initialization_eqs = [], use_scc = true, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) @@ -641,7 +823,7 @@ function maybe_build_initialization_problem( end initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( - sys, t, u0map, pmap; guesses, kwargs...) + sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, kwargs...) if state_values(initializeprob) !== nothing initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob))) end @@ -658,7 +840,10 @@ function maybe_build_initialization_problem( end initializeprob = remake(initializeprob; p = initp) - meta = get_metadata(initializeprob.f.sys) + meta = InitializationMetadata( + u0map, pmap, guesses, Vector{Equation}(initialization_eqs), + use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys), + setp(sys, Initial.(unknowns(sys)))) if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) @@ -710,7 +895,7 @@ function maybe_build_initialization_problem( return (; initialization_data = SciMLBase.OverrideInitData( initializeprob, update_initializeprob!, initializeprobmap, - initializeprobpmap)) + initializeprobpmap; metadata = meta)) end """ diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 4ade3481cb..dbc69da672 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1496,3 +1496,14 @@ end @test integ3.u ≈ [2.0, 3.0] @test integ3.ps[c1] ≈ 2.0 end + +# https://github.com/SciML/SciMLBase.jl/issues/985 +@testset "Type-stability of `remake`" begin + @parameters α=1 β=1 γ=1 δ=1 + @variables x(t)=1 y(t)=1 + eqs = [D(x) ~ α * x - β * x * y, D(y) ~ -δ * y + γ * x * y] + @named sys = ODESystem(eqs, t) + prob = ODEProblem(complete(sys), [], (0.0, 1)) + @inferred remake(prob; u0 = 2 .* prob.u0, p = prob.p) + @inferred solve(prob) +end