diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 2237ba8952..2f57bb1765 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -144,6 +144,7 @@ include("systems/abstractsystem.jl") include("systems/model_parsing.jl") include("systems/connectors.jl") include("systems/callbacks.jl") +include("systems/problem_utils.jl") include("systems/nonlinear/nonlinearsystem.jl") include("systems/diffeqs/odesystem.jl") diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index efec8b7fd8..b6e2a79c3e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2914,7 +2914,7 @@ function Base.eltype(::Type{<:TreeIterator{ModelingToolkit.AbstractSystem}}) end function check_array_equations_unknowns(eqs, dvs) - if any(eq -> Symbolics.isarraysymbolic(eq.lhs), eqs) + if any(eq -> eq isa Equation && Symbolics.isarraysymbolic(eq.lhs), eqs) throw(ArgumentError("The system has array equations. Call `structural_simplify` to handle such equations or scalarize them manually.")) end if any(x -> Symbolics.isarraysymbolic(x), dvs) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 8ebe6d93d0..e0d4c72f2b 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -702,302 +702,6 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys), !linenumbers ? Base.remove_linenums!(ex) : ex end -""" - u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true) - -Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point. -""" -function get_u0_p(sys, - u0map, - parammap = nothing; - t0 = nothing, - use_union = true, - tofloat = true, - symbolic_u0 = false) - dvs = unknowns(sys) - ps = parameters(sys) - - defs = defaults(sys) - if t0 !== nothing - defs[get_iv(sys)] = t0 - end - if parammap !== nothing - defs = mergedefaults(defs, parammap, ps) - end - if u0map isa Vector && eltype(u0map) <: Pair - u0map = Dict(u0map) - end - if u0map isa Dict - allobs = Set(getproperty.(observed(sys), :lhs)) - if any(in(allobs), keys(u0map)) - u0s_in_obs = filter(in(allobs), keys(u0map)) - @warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored." - end - end - obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys))) - observedmap = isempty(obs) ? Dict() : todict(obs) - defs = mergedefaults(defs, observedmap, u0map, dvs) - for (k, v) in defs - if Symbolics.isarraysymbolic(k) - ks = scalarize(k) - length(ks) == length(v) || error("$k has default value $v with unmatched size") - for (kk, vv) in zip(ks, v) - if !haskey(defs, kk) - defs[kk] = vv - end - end - end - end - - if symbolic_u0 - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) - else - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union) - end - p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) - p = p === nothing ? SciMLBase.NullParameters() : p - t0 !== nothing && delete!(defs, get_iv(sys)) - u0, p, defs -end - -function get_u0( - sys, u0map, parammap = nothing; symbolic_u0 = false, - toterm = default_toterm, t0 = nothing, use_union = true) - dvs = unknowns(sys) - ps = parameters(sys) - defs = defaults(sys) - if t0 !== nothing - defs[get_iv(sys)] = t0 - end - if parammap !== nothing - defs = mergedefaults(defs, parammap, ps) - end - - # Convert observed equations "lhs ~ rhs" into defaults. - # Use the order "lhs => rhs" by default, but flip it to "rhs => lhs" - # if "lhs" is known by other means (parameter, another default, ...) - # TODO: Is there a better way to determine which equations to flip? - obs = map(x -> x.lhs => x.rhs, observed(sys)) - obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs) - obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25" - obsmap = isempty(obs) ? Dict() : todict(obs) - - defs = mergedefaults(defs, obsmap, u0map, dvs) - if symbolic_u0 - u0 = varmap_to_vars( - u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm) - else - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm) - end - t0 !== nothing && delete!(defs, get_iv(sys)) - return u0, defs -end - -struct GetUpdatedMTKParameters{G, S} - # `getu` functor which gets parameters that are unknowns during initialization - getpunknowns::G - # `setu` functor which returns a modified MTKParameters using those parameters - setpunknowns::S -end - -function (f::GetUpdatedMTKParameters)(prob, initializesol) - mtkp = copy(parameter_values(prob)) - f.setpunknowns(mtkp, f.getpunknowns(initializesol)) - mtkp -end - -struct UpdateInitializeprob{G, S} - # `getu` functor which gets all values from prob - getvals::G - # `setu` functor which updates initializeprob with values - setvals::S -end - -function (f::UpdateInitializeprob)(initializeprob, prob) - f.setvals(initializeprob, f.getvals(prob)) -end - -function get_temporary_value(p) - stype = symtype(unwrap(p)) - return if stype == Real - zero(Float64) - elseif stype <: AbstractArray{Real} - zeros(Float64, size(p)) - elseif stype <: Real - zero(stype) - elseif stype <: AbstractArray - zeros(eltype(stype), size(p)) - else - error("Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization") - end -end - -function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; - implicit_dae = false, du0map = nothing, - version = nothing, tgrad = false, - jac = false, - checkbounds = false, sparse = false, - simplify = false, - linenumbers = true, parallel = SerialForm(), - eval_expression = false, - eval_module = @__MODULE__, - use_union = false, - tofloat = true, - symbolic_u0 = false, - u0_constructor = identity, - guesses = Dict(), - t = nothing, - warn_initialize_determined = true, - build_initializeprob = true, - initialization_eqs = [], - fully_determined = false, - check_units = true, - kwargs...) - eqs = equations(sys) - dvs = unknowns(sys) - ps = parameters(sys) - iv = get_iv(sys) - - check_array_equations_unknowns(eqs, dvs) - # TODO: Pass already computed information to varmap_to_vars call - # in process_u0? That would just be a small optimization - varmap = u0map === nothing || isempty(u0map) || eltype(u0map) <: Number ? - defaults(sys) : - merge(defaults(sys), todict(u0map)) - varmap = canonicalize_varmap(varmap) - varlist = collect(map(unwrap, dvs)) - missingvars = setdiff(varlist, collect(keys(varmap))) - setobserved = filter(keys(varmap)) do var - has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var)) - end - - if eltype(parammap) <: Pair - parammap = Dict{Any, Any}(unwrap(k) => v for (k, v) in parammap) - elseif parammap isa AbstractArray - if isempty(parammap) - parammap = SciMLBase.NullParameters() - else - parammap = Dict{Any, Any}(unwrap.(parameters(sys)) .=> parammap) - end - end - defs = defaults(sys) - if has_guesses(sys) - guesses = merge( - ModelingToolkit.guesses(sys), isempty(guesses) ? Dict() : todict(guesses)) - solvablepars = [p - for p in parameters(sys) - if is_parameter_solvable(p, parammap, defs, guesses)] - - pvarmap = if parammap === nothing || parammap == SciMLBase.NullParameters() || - !(eltype(parammap) <: Pair) && isempty(parammap) - defs - else - merge(defs, todict(parammap)) - end - setparobserved = filter(keys(pvarmap)) do var - has_parameter_dependency_with_lhs(sys, var) - end - else - solvablepars = () - setparobserved = () - end - # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first - if sys isa ODESystem && build_initializeprob && - (((implicit_dae || !isempty(missingvars) || !isempty(solvablepars) || - !isempty(setobserved) || !isempty(setparobserved)) && - ModelingToolkit.get_tearing_state(sys) !== nothing) || - !isempty(initialization_equations(sys))) && t !== nothing - if eltype(u0map) <: Number - u0map = unknowns(sys) .=> vec(u0map) - end - if u0map === nothing || isempty(u0map) - u0map = Dict() - end - - initializeprob = ModelingToolkit.InitializationProblem( - sys, t, u0map, parammap; guesses, warn_initialize_determined, - initialization_eqs, eval_expression, eval_module, fully_determined, check_units) - initializeprobmap = getu(initializeprob, unknowns(sys)) - punknowns = [p - for p in all_variable_symbols(initializeprob) if is_parameter(sys, p)] - getpunknowns = getu(initializeprob, punknowns) - setpunknowns = setp(sys, punknowns) - initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) - reqd_syms = parameter_symbols(initializeprob) - update_initializeprob! = UpdateInitializeprob( - getu(sys, reqd_syms), setu(initializeprob, reqd_syms)) - - zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0) - if parammap isa SciMLBase.NullParameters - parammap = Dict() - end - for p in punknowns - p = unwrap(p) - stype = symtype(p) - parammap[p] = get_temporary_value(p) - end - trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map)) - u0map isa StaticArraysCore.StaticArray && - (trueinit = SVector{length(trueinit)}(trueinit)) - else - initializeprob = nothing - update_initializeprob! = nothing - initializeprobmap = nothing - initializeprobpmap = nothing - trueinit = u0map - end - - if has_index_cache(sys) && get_index_cache(sys) !== nothing - u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0, - t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t, use_union) - check_eqs_u0(eqs, dvs, u0; kwargs...) - p = if parammap === nothing || - parammap == SciMLBase.NullParameters() && isempty(defs) - nothing - else - MTKParameters(sys, parammap, trueinit; t0 = t) - end - else - u0, p, defs = get_u0_p(sys, - trueinit, - parammap; - tofloat, - use_union, - t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t, - symbolic_u0) - p, split_idxs = split_parameters_by_type(p) - if p isa Tuple - ps = Base.Fix1(getindex, parameters(sys)).(split_idxs) - ps = (ps...,) #if p is Tuple, ps should be Tuple - end - end - if u0 !== nothing - u0 = u0_constructor(u0) - end - - if implicit_dae && du0map !== nothing - ddvs = map(Differential(iv), dvs) - defs = mergedefaults(defs, du0map, ddvs) - du0 = varmap_to_vars(du0map, ddvs; defaults = defs, toterm = identity, - tofloat = true) - else - du0 = nothing - ddvs = nothing - end - check_eqs_u0(eqs, dvs, u0; kwargs...) - f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac, - checkbounds = checkbounds, p = p, - linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, - eval_module = eval_module, - initializeprob = initializeprob, - update_initializeprob! = update_initializeprob!, - initializeprobmap = initializeprobmap, - initializeprobpmap = initializeprobpmap, - kwargs...) - implicit_dae ? (f, du0, u0, p) : (f, u0, p) -end - function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...) ODEFunctionExpr{true}(sys, args...; kwargs...) end @@ -1104,7 +808,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`") end - f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) @@ -1147,7 +851,7 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`") end - f, du0, u0, p = process_DEProblem(DAEFunction{iip}, sys, u0map, parammap; + f, du0, u0, p = process_SciMLProblem(DAEFunction{iip}, sys, u0map, parammap; implicit_dae = true, du0map = du0map, check_length, t = tspan !== nothing ? tspan[1] : tspan, warn_initialize_determined, kwargs...) @@ -1179,7 +883,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`") end - f, u0, p = process_DEProblem(DDEFunction{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, symbolic_u0 = true, check_length, eval_expression, eval_module, kwargs...) @@ -1214,7 +918,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`") end - f, u0, p = process_DEProblem(SDDEFunction{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, symbolic_u0 = true, eval_expression, eval_module, check_length, kwargs...) @@ -1274,7 +978,8 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan, if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `ODEProblemExpr`") end - f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; check_length, + f, u0, p = process_SciMLProblem( + ODEFunctionExpr{iip}, sys, u0map, parammap; check_length, t = tspan !== nothing ? tspan[1] : tspan, kwargs...) linenumbers = get(kwargs, :linenumbers, true) @@ -1320,7 +1025,7 @@ function DAEProblemExpr{iip}(sys::AbstractODESystem, du0map, u0map, tspan, if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblemExpr`") end - f, du0, u0, p = process_DEProblem(DAEFunctionExpr{iip}, sys, u0map, parammap; + f, du0, u0, p = process_SciMLProblem(DAEFunctionExpr{iip}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, implicit_dae = true, du0map = du0map, check_length, kwargs...) @@ -1372,7 +1077,7 @@ function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem, u0map, if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SteadyStateProblem`") end - f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(ODEFunction{iip}, sys, u0map, parammap; steady_state = true, check_length, kwargs...) kwargs = filter_kwargs(kwargs) @@ -1404,7 +1109,7 @@ function SteadyStateProblemExpr{iip}(sys::AbstractODESystem, u0map, if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SteadyStateProblemExpr`") end - f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; steady_state = true, check_length, kwargs...) linenumbers = get(kwargs, :linenumbers, true) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 1b368f7df3..0d73aaf313 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -659,7 +659,7 @@ function DiffEqBase.SDEProblem{iip, specialize}( if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`") end - f, u0, p = process_DEProblem( + f, u0, p = process_SciMLProblem( SDEFunction{iip, specialize}, sys, u0map, parammap; check_length, kwargs...) cbs = process_events(sys; callback, kwargs...) @@ -745,7 +745,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan, if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblemExpr`") end - f, u0, p = process_DEProblem(SDEFunctionExpr{iip}, sys, u0map, parammap; check_length, + f, u0, p = process_SciMLProblem( + SDEFunctionExpr{iip}, sys, u0map, parammap; check_length, kwargs...) linenumbers = get(kwargs, :linenumbers, true) sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false)) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 7103cfca80..bf7879be62 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -236,55 +236,25 @@ function generate_function( generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...) end -function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap; - linenumbers = true, parallel = SerialForm(), - use_union = false, - tofloat = !use_union, - eval_expression = false, eval_module = @__MODULE__, - kwargs...) +function shift_u0map_forward(sys::DiscreteSystem, u0map, defs) iv = get_iv(sys) - eqs = equations(sys) - dvs = unknowns(sys) - ps = parameters(sys) - - if eltype(u0map) <: Number - u0map = unknowns(sys) .=> vec(u0map) - end - if u0map === nothing || isempty(u0map) - u0map = Dict() - end - - trueu0map = Dict() - for (k, v) in u0map - k = unwrap(k) + updated = AnyDict() + for k in collect(keys(u0map)) + v = u0map[k] if !((op = operation(k)) isa Shift) error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).") end - trueu0map[Shift(iv, op.steps + 1)(arguments(k)[1])] = v - end - defs = ModelingToolkit.get_defaults(sys) - for var in dvs - if (op = operation(var)) isa Shift && !haskey(trueu0map, var) - root = arguments(var)[1] - haskey(defs, root) || error("Initial condition for $var not provided.") - trueu0map[var] = defs[root] - end + updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v end - if has_index_cache(sys) && get_index_cache(sys) !== nothing - u0, defs = get_u0(sys, trueu0map, parammap) - p = MTKParameters(sys, parammap, trueu0map) - else - u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union) + for var in unknowns(sys) + op = operation(var) + op isa Shift || continue + haskey(updated, var) && continue + root = first(arguments(var)) + haskey(defs, root) || error("Initial condition for $var not provided.") + updated[var] = defs[root] end - - check_eqs_u0(eqs, dvs, u0; kwargs...) - - f = constructor(sys, dvs, ps, u0; - linenumbers = linenumbers, parallel = parallel, - syms = Symbol.(dvs), paramsyms = Symbol.(ps), - eval_expression = eval_expression, eval_module = eval_module, - kwargs...) - return f, u0, p + return updated end """ @@ -307,7 +277,9 @@ function SciMLBase.DiscreteProblem( eqs = equations(sys) iv = get_iv(sys) - f, u0, p = process_DiscreteProblem( + u0map = to_varmap(u0map, dvs) + u0map = shift_u0map_forward(sys, u0map, defaults(sys)) + f, u0, p = process_SciMLProblem( DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module) u0 = f(u0, p, tspan[1]) DiscreteProblem(f, u0, tspan, p; kwargs...) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 60b061b1da..a714d4b364 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -348,20 +348,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, if !iscomplete(sys) error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") end - dvs = unknowns(sys) - ps = parameters(sys) - - defs = defaults(sys) - defs = mergedefaults(defs, parammap, ps) - defs = mergedefaults(defs, u0map, dvs) - - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false) - if has_index_cache(sys) && get_index_cache(sys) !== nothing - p = MTKParameters(sys, parammap, u0map) - else - p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union) - end - + _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false) f = DiffEqBase.DISCRETE_INPLACE_DEFAULT observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) @@ -399,16 +387,9 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No if !iscomplete(sys) error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`") end - dvs = unknowns(sys) - ps = parameters(sys) - defs = defaults(sys) - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false) - if has_index_cache(sys) && get_index_cache(sys) !== nothing - p = MTKParameters(sys, parammap, u0map) - else - p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union) - end + _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false) # identity function to make syms works quote f = DiffEqBase.DISCRETE_INPLACE_DEFAULT @@ -454,19 +435,9 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi if !iscomplete(sys) error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") end - dvs = unknowns(sys) - ps = parameters(sys) - - defs = defaults(sys) - defs = mergedefaults(defs, parammap, ps) - defs = mergedefaults(defs, u0map, dvs) - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false) - if has_index_cache(sys) && get_index_cache(sys) !== nothing - p = MTKParameters(sys, parammap, u0map) - else - p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union) - end + _, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap; + t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false) observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 38a3095b5c..99e4e19d09 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -37,7 +37,7 @@ function generate_initializesystem(sys::ODESystem; # set dummy derivatives to default_dd_guess unless specified push!(defs, x[1] => get(guesses, x[1], default_dd_guess)) end - for (y, x) in u0map + function process_u0map_with_dummysubs(y, x) y = get(schedule.dummy_sub, y, y) y = fixpoint_sub(y, diffmap) if y ∈ vars_set @@ -53,6 +53,13 @@ function generate_initializesystem(sys::ODESystem; error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.") end end + for (y, x) in u0map + if Symbolics.isarraysymbolic(y) + process_u0map_with_dummysubs.(collect(y), collect(x)) + else + process_u0map_with_dummysubs(y, x) + end + end end # 2) process other variables diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 46bf032d6f..a0cd636753 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -293,7 +293,7 @@ function SciMLBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...) end function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), - ps = parameters(sys), u0 = nothing, p = nothing; + ps = parameters(sys), u0 = nothing; p = nothing, version = nothing, jac = false, eval_expression = false, @@ -408,36 +408,6 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys), !linenumbers ? Base.remove_linenums!(ex) : ex end -function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, parammap; - version = nothing, - jac = false, - checkbounds = false, sparse = false, - simplify = false, - linenumbers = true, parallel = SerialForm(), - eval_expression = false, - eval_module = @__MODULE__, - use_union = false, - tofloat = !use_union, - kwargs...) - eqs = equations(sys) - dvs = unknowns(sys) - ps = parameters(sys) - if has_index_cache(sys) && get_index_cache(sys) !== nothing - u0, defs = get_u0(sys, u0map, parammap) - check_eqs_u0(eqs, dvs, u0; kwargs...) - p = MTKParameters(sys, parammap, u0map) - else - u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union) - check_eqs_u0(eqs, dvs, u0; kwargs...) - end - - f = constructor(sys, dvs, ps, u0, p; jac = jac, checkbounds = checkbounds, - linenumbers = linenumbers, parallel = parallel, simplify = simplify, - sparse = sparse, eval_expression = eval_expression, eval_module = eval_module, - kwargs...) - return f, u0, p -end - """ ```julia DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, @@ -461,7 +431,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`") end - f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap; check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...) @@ -490,7 +460,7 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearLeastSquaresProblem`") end - f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap; check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...) @@ -523,7 +493,7 @@ function NonlinearProblemExpr{iip}(sys::NonlinearSystem, u0map, if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblemExpr`") end - f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap; check_length, kwargs...) linenumbers = get(kwargs, :linenumbers, true) @@ -563,7 +533,7 @@ function NonlinearLeastSquaresProblemExpr{iip}(sys::NonlinearSystem, u0map, if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblemExpr`") end - f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap; + f, u0, p = process_SciMLProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap; check_length, kwargs...) linenumbers = get(kwargs, :linenumbers, true) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl new file mode 100644 index 0000000000..e530e62eed --- /dev/null +++ b/src/systems/problem_utils.jl @@ -0,0 +1,610 @@ +const AnyDict = Dict{Any, Any} + +""" + $(TYPEDSIGNATURES) + +If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input +as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters` +and `nothing`. +""" +anydict() = AnyDict() +anydict(::SciMLBase.NullParameters) = AnyDict() +anydict(::Nothing) = AnyDict() +anydict(x::AnyDict) = x +anydict(x) = AnyDict(x) + +""" + $(TYPEDSIGNATURES) + +Check if `x` is a symbolic with known size. Assumes `Symbolics.shape(unwrap(x))` +is a valid operation. +""" +is_sized_array_symbolic(x) = Symbolics.shape(unwrap(x)) != Symbolics.Unknown() + +""" + $(TYPEDSIGNATURES) + +Check if the system is in split form (has an `IndexCache`). +""" +is_split(sys::AbstractSystem) = has_index_cache(sys) && get_index_cache(sys) !== nothing + +""" + $(TYPEDSIGNATURES) + +Given a variable-value mapping, add mappings for the `toterm` of each of the keys. +""" +function add_toterms!(varmap::AbstractDict; toterm = default_toterm) + for k in collect(keys(varmap)) + varmap[toterm(k)] = varmap[k] + end + return nothing +end + +""" + $(TYPEDSIGNATURES) + +Out-of-place version of [`add_toterms!`](@ref). +""" +function add_toterms(varmap::AbstractDict; toterm = default_toterm) + cp = copy(varmap) + add_toterms!(cp; toterm) + return cp +end + +""" + $(TYPEDSIGNATURES) + +Ensure `varmap` contains entries for all variables in `vars` by using values from +`fallbacks` if they don't already exist in `varmap`. Return the set of all variables in +`vars` not present in `varmap` or `fallbacks`. If an array variable in `vars` does not +exist in `varmap` or `fallbacks`, each of its scalarized elements will be searched for. +In case none of the scalarized elements exist, the array variable will be reported as +missing. In case some of the scalarized elements exist, the missing elements will be +reported as missing. If `fallbacks` contains both the scalarized and non-scalarized forms, +the latter will take priority. + +Variables as they are specified in `vars` will take priority over their `toterm` forms. +""" +function add_fallbacks!( + varmap::AnyDict, vars::Vector, fallbacks::Dict; toterm = default_toterm) + missingvars = Set() + for var in vars + haskey(varmap, var) && continue + ttvar = toterm(var) + haskey(varmap, ttvar) && continue + + # array symbolics with a defined size may be present in the scalarized form + if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + val = map(eachindex(var)) do idx + # @something is lazy and saves from writing a massive if-elseif-else + @something(get(varmap, var[idx], nothing), + get(varmap, ttvar[idx], nothing), get(fallbacks, var, nothing)[idx], + get(fallbacks, ttvar, nothing)[idx], get(fallbacks, var[idx], nothing), + get(fallbacks, ttvar[idx], nothing), Some(nothing)) + end + # only push the missing entries + mask = map(x -> x === nothing, val) + if all(mask) + push!(missingvars, var) + elseif any(mask) + for i in eachindex(var) + if mask[i] + push!(missingvars, var) + else + varmap[var[i]] = val[i] + end + end + else + varmap[var] = val + end + else + if iscall(var) && operation(var) == getindex + args = arguments(var) + arrvar = args[1] + ttarrvar = toterm(arrvar) + idxs = args[2:end] + val = @something get(varmap, arrvar, nothing) get(varmap, ttarrvar, nothing) get( + fallbacks, arrvar, nothing) get(fallbacks, ttarrvar, nothing) Some(nothing) + if val !== nothing + val = val[idxs...] + end + else + val = nothing + end + val = @something val get(fallbacks, var, nothing) get(fallbacks, ttvar, nothing) Some(nothing) + if val === nothing + push!(missingvars, var) + else + varmap[var] = val + end + end + end + + return missingvars +end + +""" + $(TYPEDSIGNATURES) + +Return the list of variables in `varlist` not present in `varmap`. Uses the same criteria +for missing array variables and `toterm` forms as [`add_fallbacks!`](@ref). +""" +function missingvars( + varmap::AbstractDict, varlist::Vector; toterm = default_toterm) + missingvars = Set() + for var in varlist + haskey(varmap, var) && continue + ttsym = toterm(var) + haskey(varmap, ttsym) && continue + + if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + mask = map(eachindex(var)) do idx + !haskey(varmap, var[idx]) && !haskey(varmap, ttsym[idx]) + end + if all(mask) + push!(missingvars, var) + else + for i in eachindex(var) + mask[i] && push!(missingvars, var[i]) + end + end + else + push!(missingvars, var) + end + end + return missingvars +end + +""" + $(TYPEDSIGNATURES) + +Attempt to interpret `vals` as a symbolic map of variables in `varlist` to values. Return +the result as a `Dict{Any, Any}`. In case `vals` is already an iterable of pairs, convert +it to a `Dict{Any, Any}` and return. If `vals` is an array (whose `eltype` is not `Pair`) +with the same length as `varlist`, assume the `i`th element of `varlist` is mapped to the +`i`th element of `vals`. Automatically `unwrap`s all keys and values in the mapping. Also +handles `SciMLBase.NullParameters` and `nothing`, both of which are interpreted as empty +maps. +""" +function to_varmap(vals, varlist::Vector) + if vals isa AbstractArray && !(eltype(vals) <: Pair) && !isempty(vals) + check_eqs_u0(varlist, varlist, vals) + vals = vec(varlist) .=> vec(vals) + end + return anydict(unwrap(k) => unwrap(v) for (k, v) in anydict(vals)) +end + +""" + $(TYPEDSIGNATURES) + +Return the appropriate zero value for a symbolic variable representing a number or array of +numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array +symbolics return an empty array of the appropriate `eltype`. +""" +function zero_var(x::Symbolic{T}) where {V <: Number, T <: Union{V, AbstractArray{V}}} + if Symbolics.isarraysymbolic(x) + if is_sized_array_symbolic(x) + return zeros(T, size(x)) + else + return T[] + end + else + return zero(T) + end +end + +""" + $(TYPEDSIGNATURES) + +Add equations `eqs` to `varmap`. Assumes each element in `eqs` maps a single symbolic +variable to an expression representing its value. In case `varmap` already contains an +entry for `eq.lhs`, insert the reverse mapping if `eq.rhs` is not a number. +""" +function add_observed_equations!(varmap::AbstractDict, eqs) + for eq in eqs + if haskey(varmap, eq.lhs) + eq.rhs isa Number && continue + haskey(varmap, eq.rhs) && continue + !iscall(eq.rhs) || issym(operation(eq.rhs)) || continue + varmap[eq.rhs] = eq.lhs + else + varmap[eq.lhs] = eq.rhs + end + end +end + +""" + $(TYPEDSIGNATURES) + +Add all equations in `observed(sys)` to `varmap` using [`add_observed_equations!`](@ref). +""" +function add_observed!(sys::AbstractSystem, varmap::AbstractDict) + add_observed_equations!(varmap, observed(sys)) +end + +""" + $(TYPEDSIGNATURES) + +Add all equations in `parameter_dependencies(sys)` to `varmap` using +[`add_observed_equations!`](@ref). +""" +function add_parameter_dependencies!(sys::AbstractSystem, varmap::AbstractDict) + has_parameter_dependencies(sys) || return nothing + add_observed_equations!(varmap, parameter_dependencies(sys)) +end + +""" + $(TYPEDSIGNATURES) + +Return an array of values where the `i`th element corresponds to the value of `vars[i]` +in `varmap`. Does not perform symbolic substitution in the values of `varmap`. + +Keyword arguments: +- `tofloat`: Convert values to floating point numbers using `float`. +- `use_union`: Use a `Union`-typed array if the values have heterogeneous types. +- `container_type`: The type of container to use for the values. +- `toterm`: The `toterm` method to use for converting symbolics. +- `promotetoconcrete`: whether the promote to a concrete buffer (respecting + `tofloat` and `use_union`). Defaults to `container_type <: AbstractArray`. +- `check`: Error if any variables in `vars` do not have a mapping in `varmap`. Uses + [`missingvars`](@ref) to perform the check. +- `allow_symbolic` allows the returned array to contain symbolic values. If this is `true`, + `promotetoconcrete` is set to `false`. +""" +function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; + tofloat = true, use_union = true, container_type = Array, + toterm = default_toterm, promotetoconcrete = nothing, check = true, allow_symbolic = false) + isempty(vars) && return nothing + + if check + missing_vars = missingvars(varmap, vars; toterm) + isempty(missing_vars) || throw(MissingVariablesError(missing_vars)) + end + vals = map(x -> varmap[x], vars) + + if container_type <: Union{AbstractDict, Tuple, Nothing} + container_type = Array + end + + promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray) + if promotetoconcrete && !allow_symbolic + vals = promote_to_concrete(vals; tofloat = tofloat, use_union = use_union) + end + + if isempty(vals) + return nothing + elseif container_type <: Tuple + return (vals...,) + else + return SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(), + Val(length(vals)), vals...) + end +end + +""" + $(TYPEDSIGNATURES) + +Performs symbolic substitution on the values in `varmap`, using `varmap` itself as the +set of substitution rules. +""" +function evaluate_varmap!(varmap::AbstractDict) + for (k, v) in varmap + varmap[k] = fixpoint_sub(v, varmap) + end +end + +struct GetUpdatedMTKParameters{G, S} + # `getu` functor which gets parameters that are unknowns during initialization + getpunknowns::G + # `setu` functor which returns a modified MTKParameters using those parameters + setpunknowns::S +end + +function (f::GetUpdatedMTKParameters)(prob, initializesol) + mtkp = copy(parameter_values(prob)) + f.setpunknowns(mtkp, f.getpunknowns(initializesol)) + mtkp +end + +struct UpdateInitializeprob{G, S} + # `getu` functor which gets all values from prob + getvals::G + # `setu` functor which updates initializeprob with values + setvals::S +end + +function (f::UpdateInitializeprob)(initializeprob, prob) + f.setvals(initializeprob, f.getvals(prob)) +end + +function get_temporary_value(p) + stype = symtype(unwrap(p)) + return if stype == Real + zero(Float64) + elseif stype <: AbstractArray{Real} + zeros(Float64, size(p)) + elseif stype <: Real + zero(stype) + elseif stype <: AbstractArray + zeros(eltype(stype), size(p)) + else + error("Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization") + end +end + +""" + $(TYPEDEF) + +A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in +case constructing a SciMLFunction is not required. +""" +struct EmptySciMLFunction end + +function EmptySciMLFunction(args...; kwargs...) + return nothing +end + +""" + $(TYPEDSIGNATURES) + +Return the SciMLFunction created via calling `constructor`, the initial conditions `u0` +and parameter object `p` given the system `sys`, and user-provided initial values `u0map` +and `pmap`. `u0map` and `pmap` are converted into variable maps via [`to_varmap`](@ref). + +The order of unknowns is determined by `unknowns(sys)`. If the system is split +[`is_split`](@ref) create an [`MTKParameters`](@ref) object. Otherwise, a parameter vector. +Initial values provided in terms of other variables will be symbolically evaluated using +[`evaluate_varmap!`](@ref). The type of `u0map` and `pmap` will be used to determine the +type of the containers (if parameters are not in an `MTKParameters` object). `Dict`s will be +turned into `Array`s. + +If `sys isa ODESystem`, this will also build the initialization problem and related objects +and pass them to the SciMLFunction as keyword arguments. + +Keyword arguments: +- `build_initializeprob`: If `false`, avoids building the initialization problem for an + `ODESystem`. +- `t`: The initial time of the `ODEProblem`. If this is not provided, the initialization + problem cannot be built. +- `implicit_dae`: Also build a mapping of derivatives of states to values for implicit DAEs, + using `du0map`. Changes the return value of this function to `(f, du0, u0, p)` instead of + `(f, u0, p)`. +- `guesses`: The guesses for variables in the system, used as initial values for the + initialization problem. +- `warn_initialize_determined`: Warn if the initialization system is under/over-determined. +- `initialization_eqs`: Extra equations to use in the initialization problem. +- `eval_expression`: Whether to compile any functions via `eval` or `RuntimeGeneratedFunctions`. +- `eval_module`: If `eval_expression == true`, the module to `eval` into. Otherwise, the module + in which to generate the `RuntimeGeneratedFunction`. +- `fully_determined`: Override whether the initialization system is fully determined. +- `check_units`: Enable or disable unit checks. +- `tofloat`, `use_union`: Passed to [`better_varmap_to_vars`](@ref) for building `u0` (and + possibly `p`). +- `u0_constructor`: A function to apply to the `u0` value returned from `better_varmap_to_vars` + to construct the final `u0` value. +- `du0map`: A map of derivatives to values. See `implicit_dae`. +- `check_length`: Whether to check the number of equations along with number of unknowns and + length of `u0` vector for consistency. If `false`, do not check with equations. This is + forwarded to `check_eqs_u0` +- `symbolic_u0` allows the returned `u0` to be an array of symbolics. + +All other keyword arguments are passed as-is to `constructor`. +""" +function process_SciMLProblem( + constructor, sys::AbstractSystem, u0map, pmap; build_initializeprob = true, + implicit_dae = false, t = nothing, guesses = AnyDict(), + warn_initialize_determined = true, initialization_eqs = [], + eval_expression = false, eval_module = @__MODULE__, fully_determined = false, + check_units = true, tofloat = true, use_union = false, + u0_constructor = identity, du0map = nothing, check_length = true, symbolic_u0 = false, kwargs...) + dvs = unknowns(sys) + ps = parameters(sys) + iv = has_iv(sys) ? get_iv(sys) : nothing + eqs = equations(sys) + + check_array_equations_unknowns(eqs, dvs) + + u0Type = typeof(u0map) + pType = typeof(pmap) + _u0map = u0map + u0map = to_varmap(u0map, dvs) + _pmap = pmap + pmap = to_varmap(pmap, ps) + defs = add_toterms(defaults(sys)) + cmap, cs = get_cmap(sys) + kwargs = NamedTuple(kwargs) + + op = add_toterms(u0map) + missing_unknowns = add_fallbacks!(op, dvs, defs) + for (k, v) in defs + haskey(op, k) && continue + op[k] = v + end + merge!(op, pmap) + missing_pars = add_fallbacks!(op, ps, defs) + for eq in cmap + op[eq.lhs] = eq.rhs + end + if sys isa ODESystem + guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) + has_observed_u0s = any( + k -> has_observed_with_lhs(sys, k) || has_parameter_dependency_with_lhs(sys, k), + keys(op)) + solvablepars = [p + for p in parameters(sys) + if is_parameter_solvable(p, pmap, defs, guesses)] + if build_initializeprob && + (((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) || + !isempty(solvablepars)) && + get_tearing_state(sys) !== nothing) || + !isempty(initialization_equations(sys))) && t !== nothing + initializeprob = ModelingToolkit.InitializationProblem( + sys, t, u0map, pmap; guesses, warn_initialize_determined, + initialization_eqs, eval_expression, eval_module, fully_determined, check_units) + initializeprobmap = getu(initializeprob, unknowns(sys)) + + punknowns = [p + for p in all_variable_symbols(initializeprob) + if is_parameter(sys, p)] + getpunknowns = getu(initializeprob, punknowns) + setpunknowns = setp(sys, punknowns) + initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) + + reqd_syms = parameter_symbols(initializeprob) + update_initializeprob! = UpdateInitializeprob( + getu(sys, reqd_syms), setu(initializeprob, reqd_syms)) + for p in punknowns + p = unwrap(p) + stype = symtype(p) + op[p] = get_temporary_value(p) + delete!(missing_pars, p) + end + + for v in missing_unknowns + op[v] = zero_var(v) + end + empty!(missing_unknowns) + kwargs = merge(kwargs, + (; initializeprob, initializeprobmap, + initializeprobpmap, update_initializeprob!)) + end + end + + if t !== nothing && !(constructor <: Union{DDEFunction, SDDEFunction}) + op[iv] = t + end + + add_observed!(sys, op) + add_parameter_dependencies!(sys, op) + + evaluate_varmap!(op) + + u0 = better_varmap_to_vars( + op, dvs; tofloat = true, use_union = false, + container_type = u0Type, allow_symbolic = symbolic_u0) + + if u0 !== nothing + u0 = u0_constructor(u0) + end + + check_eqs_u0(eqs, dvs, u0; check_length, kwargs...) + + if is_split(sys) + p = MTKParameters(sys, op) + else + p = better_varmap_to_vars(op, ps; tofloat, use_union, container_type = pType) + end + + if implicit_dae && du0map !== nothing + ddvs = map(Differential(iv), dvs) + du0map = to_varmap(du0map, ddvs) + merge!(op, du0map) + + du0 = varmap_to_vars(du0map, ddvs; toterm = identity, + tofloat = true) + kwargs = merge(kwargs, (; ddvs)) + else + du0 = nothing + end + + f = constructor(sys, dvs, ps, u0; p = p, + eval_expression = eval_expression, + eval_module = eval_module, + kwargs...) + implicit_dae ? (f, du0, u0, p) : (f, u0, p) +end + +############## +# Legacy functions for backward compatibility +############## + +""" + u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true) + +Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point. +""" +function get_u0_p(sys, + u0map, + parammap = nothing; + t0 = nothing, + use_union = true, + tofloat = true, + symbolic_u0 = false) + dvs = unknowns(sys) + ps = parameters(sys) + + defs = defaults(sys) + if t0 !== nothing + defs[get_iv(sys)] = t0 + end + if parammap !== nothing + defs = mergedefaults(defs, parammap, ps) + end + if u0map isa Vector && eltype(u0map) <: Pair + u0map = Dict(u0map) + end + if u0map isa Dict + allobs = Set(getproperty.(observed(sys), :lhs)) + if any(in(allobs), keys(u0map)) + u0s_in_obs = filter(in(allobs), keys(u0map)) + @warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored." + end + end + obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys))) + observedmap = isempty(obs) ? Dict() : todict(obs) + defs = mergedefaults(defs, observedmap, u0map, dvs) + for (k, v) in defs + if Symbolics.isarraysymbolic(k) + ks = scalarize(k) + length(ks) == length(v) || error("$k has default value $v with unmatched size") + for (kk, vv) in zip(ks, v) + if !haskey(defs, kk) + defs[kk] = vv + end + end + end + end + + if symbolic_u0 + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) + else + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union) + end + p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) + p = p === nothing ? SciMLBase.NullParameters() : p + t0 !== nothing && delete!(defs, get_iv(sys)) + u0, p, defs +end + +function get_u0( + sys, u0map, parammap = nothing; symbolic_u0 = false, + toterm = default_toterm, t0 = nothing, use_union = true) + dvs = unknowns(sys) + ps = parameters(sys) + defs = defaults(sys) + if t0 !== nothing + defs[get_iv(sys)] = t0 + end + if parammap !== nothing + defs = mergedefaults(defs, parammap, ps) + end + + # Convert observed equations "lhs ~ rhs" into defaults. + # Use the order "lhs => rhs" by default, but flip it to "rhs => lhs" + # if "lhs" is known by other means (parameter, another default, ...) + # TODO: Is there a better way to determine which equations to flip? + obs = map(x -> x.lhs => x.rhs, observed(sys)) + obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs) + obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25" + obsmap = isempty(obs) ? Dict() : todict(obs) + + defs = mergedefaults(defs, obsmap, u0map, dvs) + if symbolic_u0 + u0 = varmap_to_vars( + u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm) + else + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm) + end + t0 !== nothing && delete!(defs, get_iv(sys)) + return u0, defs +end diff --git a/src/utils.jl b/src/utils.jl index e8ed131d78..1e54e7047b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -610,6 +610,15 @@ function collect_constants!(constants, expr::Symbolic) end end +function collect_constants!(constants, expr::Union{ConstantRateJump, VariableRateJump}) + collect_constants!(constants, expr.rate) + collect_constants!(constants, expr.affect!) +end + +function collect_constants!(constants, ::MassActionJump) + return constants +end + """ Replace symbolic constants with their literal values """ @@ -667,7 +676,7 @@ end function get_cmap(sys, exprs = nothing) #Inject substitutions for constants => values - cs = collect_constants([get_eqs(sys); get_observed(sys)]) #ctrls? what else? + cs = collect_constants([collect(get_eqs(sys)); get_observed(sys)]) #ctrls? what else? if !empty_substitutions(sys) cs = [cs; collect_constants(get_substitutions(sys).subs)] end diff --git a/src/variables.jl b/src/variables.jl index 1c22540c63..c0c875450c 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -140,7 +140,7 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, if is_incomplete_initialization || isempty(varmap) if isempty(defaults) if !is_incomplete_initialization && check - isempty(varlist) || throw_missingvars(varlist) + isempty(varlist) || throw(MissingVariablesError(varlist)) end return nothing else diff --git a/test/constants.jl b/test/constants.jl index bfdc83bafc..f2c4fdaa86 100644 --- a/test/constants.jl +++ b/test/constants.jl @@ -37,3 +37,16 @@ eqs = [D(x) ~ β] simp = structural_simplify(sys) @test isempty(MT.collect_constants(nothing)) + +@testset "Issue#3044" begin + @constants h = 1 + @parameters τ = 0.5 * h + @variables x(MT.t_nounits) = h + eqs = [MT.D_nounits(x) ~ (h - x) / τ] + + @mtkbuild fol_model = ODESystem(eqs, MT.t_nounits) + + prob = ODEProblem(fol_model, [], (0.0, 10.0)) + @test prob[x] ≈ 1 + @test prob.ps[τ] ≈ 0.5 +end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 08b653c22e..e819980752 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -487,7 +487,8 @@ sys = extend(sysx, sysy) @variables x(t) y(t) @named sys = ODESystem([x^2 + y^2 ~ 25, D(x) ~ 1], t) ssys = structural_simplify(sys) - @test_throws ArgumentError ODEProblem(ssys, [x => 3], (0, 1), []) # y should have a guess + @test_throws ModelingToolkit.MissingVariablesError ODEProblem( + ssys, [x => 3], (0, 1), []) # y should have a guess end # https://github.com/SciML/ModelingToolkit.jl/issues/3025 diff --git a/test/sdesystem.jl b/test/sdesystem.jl index cae9ec9376..c258a4142b 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -36,10 +36,6 @@ solexpr = solve(eval(probexpr), SRIW1(), seed = 1) @test all(x -> x == 0, Array(sol - solexpr)) -# Test no error -@test_nowarn SDEProblem(de, nothing, (0, 10.0)) -@test SDEProblem(de, nothing).tspan == (0.0, 10.0) - noiseeqs_nd = [0.01*x 0.01*x*y 0.02*x*z σ 0.01*y 0.02*x*z ρ β 0.01*z] diff --git a/test/split_parameters.jl b/test/split_parameters.jl index b8651238ea..22c90edf7a 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -125,7 +125,7 @@ sol = solve(prob, ImplicitEuler()); prob = ODEProblem( sys, [], tspan, []; tofloat = false, use_union = true, build_initializeprob = false) -@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}} +@test prob.p isa Vector{Union{Float64, Int64}} sol = solve(prob, ImplicitEuler()); @test sol.retcode == ReturnCode.Success