Skip to content

Commit d0278cc

Browse files
refactor: remove get_u0_p, modernize get_u0 and add get_p
1 parent 41208b2 commit d0278cc

File tree

1 file changed

+29
-81
lines changed

1 file changed

+29
-81
lines changed

src/systems/problem_utils.jl

Lines changed: 29 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,97 +1700,45 @@ function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwa
17001700
remake(T(args...; kwargs...))
17011701
end
17021702

1703-
##############
1704-
# Legacy functions for backward compatibility
1705-
##############
1706-
17071703
"""
1708-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
1704+
$(TYPEDSIGNATURES)
17091705
1710-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
1706+
Return the `u0` vector for the given system `sys` and variable-value mapping `varmap`. All
1707+
keyword arguments are forwarded to [`varmap_to_vars`](@ref).
17111708
"""
1712-
function get_u0_p(sys,
1713-
u0map,
1714-
parammap = nothing;
1715-
t0 = nothing,
1716-
tofloat = true,
1717-
use_union = true,
1718-
symbolic_u0 = false)
1709+
function get_u0(sys::AbstractSystem, varmap; kwargs...)
17191710
dvs = unknowns(sys)
17201711
ps = parameters(sys; initial_parameters = true)
1712+
op = to_varmap(varmap, dvs)
1713+
add_observed!(sys, op)
1714+
add_parameter_dependencies!(sys, op)
1715+
missing_dvs, _ = build_operating_point!(sys, op, Dict(), defaults(sys), dvs, ps)
17211716

1722-
defs = defaults(sys)
1723-
if t0 !== nothing
1724-
defs[get_iv(sys)] = t0
1725-
end
1726-
if parammap !== nothing
1727-
defs = mergedefaults(defs, parammap, ps)
1728-
end
1729-
if u0map isa Vector && eltype(u0map) <: Pair
1730-
u0map = Dict(u0map)
1731-
end
1732-
if u0map isa Dict
1733-
allobs = Set(observables(sys))
1734-
if any(in(allobs), keys(u0map))
1735-
u0s_in_obs = filter(in(allobs), keys(u0map))
1736-
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
1737-
end
1738-
end
1739-
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
1740-
observedmap = isempty(obs) ? Dict() : todict(obs)
1741-
defs = mergedefaults(defs, observedmap, u0map, dvs)
1742-
for (k, v) in defs
1743-
if Symbolics.isarraysymbolic(k)
1744-
ks = scalarize(k)
1745-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
1746-
for (kk, vv) in zip(ks, v)
1747-
if !haskey(defs, kk)
1748-
defs[kk] = vv
1749-
end
1750-
end
1751-
end
1752-
end
1717+
isempty(missing_dvs) || throw(MissingVariablesError(collect(missing_dvs)))
17531718

1754-
if symbolic_u0
1755-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
1756-
else
1757-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat, use_union)
1758-
end
1759-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
1760-
p = p === nothing ? SciMLBase.NullParameters() : p
1761-
t0 !== nothing && delete!(defs, get_iv(sys))
1762-
u0, p, defs
1719+
return varmap_to_vars(op, dvs; kwargs...)
17631720
end
17641721

1765-
function get_u0(
1766-
sys, u0map, parammap = nothing; symbolic_u0 = false,
1767-
toterm = default_toterm, t0 = nothing, use_union = true)
1722+
"""
1723+
$(TYPEDSIGNATURES)
1724+
1725+
Return the `u0` vector for the given system `sys` and variable-value mapping `varmap`. All
1726+
keyword arguments are forwarded to [`MTKParameters`](@ref) for split systems and
1727+
[`varmap_to_vars`](@ref) for non-split systems.
1728+
"""
1729+
function get_p(sys::AbstractSystem, varmap; kwargs...)
17681730
dvs = unknowns(sys)
1769-
ps = parameters(sys)
1770-
defs = defaults(sys)
1771-
if t0 !== nothing
1772-
defs[get_iv(sys)] = t0
1773-
end
1774-
if parammap !== nothing
1775-
defs = mergedefaults(defs, parammap, ps)
1776-
end
1777-
1778-
# Convert observed equations "lhs ~ rhs" into defaults.
1779-
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
1780-
# if "lhs" is known by other means (parameter, another default, ...)
1781-
# TODO: Is there a better way to determine which equations to flip?
1782-
obs = map(x -> x.lhs => x.rhs, observed(sys))
1783-
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
1784-
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
1785-
obsmap = isempty(obs) ? Dict() : todict(obs)
1786-
1787-
defs = mergedefaults(defs, obsmap, u0map, dvs)
1788-
if symbolic_u0
1789-
u0 = varmap_to_vars(
1790-
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
1731+
ps = parameters(sys; initial_parameters = true)
1732+
op = to_varmap(varmap, dvs)
1733+
add_observed!(sys, op)
1734+
add_parameter_dependencies!(sys, op)
1735+
_, missing_ps = build_operating_point!(sys, op, Dict(), defaults(sys), dvs, ps)
1736+
1737+
isempty(missing_ps) || throw(MissingParametersError(collect(missing_ps)))
1738+
1739+
if is_split(sys)
1740+
MTKParameters(sys, op; kwargs...)
17911741
else
1792-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
1742+
varmap_to_vars(op, ps; kwargs...)
17931743
end
1794-
t0 !== nothing && delete!(defs, get_iv(sys))
1795-
return u0, defs
17961744
end

0 commit comments

Comments
 (0)