Skip to content

Commit c5aa808

Browse files
refactor: remove get_u0_p, modernize get_u0 and add get_p
1 parent c07372d commit c5aa808

File tree

1 file changed

+31
-81
lines changed

1 file changed

+31
-81
lines changed

src/systems/problem_utils.jl

Lines changed: 31 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,97 +1698,47 @@ function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwa
16981698
remake(T(args...; kwargs...))
16991699
end
17001700

1701-
##############
1702-
# Legacy functions for backward compatibility
1703-
##############
1704-
17051701
"""
1706-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
1702+
$(TYPEDSIGNATURES)
17071703
1708-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
1704+
Return the `u0` vector for the given system `sys` and variable-value mapping `varmap`. All
1705+
keyword arguments are forwarded to [`varmap_to_vars`](@ref).
17091706
"""
1710-
function get_u0_p(sys,
1711-
u0map,
1712-
parammap = nothing;
1713-
t0 = nothing,
1714-
tofloat = true,
1715-
use_union = true,
1716-
symbolic_u0 = false)
1707+
function get_u0(sys::AbstractSystem, varmap; kwargs...)
17171708
dvs = unknowns(sys)
17181709
ps = parameters(sys; initial_parameters = true)
1710+
op = to_varmap(varmap, dvs)
1711+
add_observed!(sys, op)
1712+
add_parameter_dependencies!(sys, op)
1713+
missing_dvs, _ = build_operating_point!(
1714+
sys, op, Dict(), Dict(), defaults(sys), dvs, ps)
17191715

1720-
defs = defaults(sys)
1721-
if t0 !== nothing
1722-
defs[get_iv(sys)] = t0
1723-
end
1724-
if parammap !== nothing
1725-
defs = mergedefaults(defs, parammap, ps)
1726-
end
1727-
if u0map isa Vector && eltype(u0map) <: Pair
1728-
u0map = Dict(u0map)
1729-
end
1730-
if u0map isa Dict
1731-
allobs = Set(observables(sys))
1732-
if any(in(allobs), keys(u0map))
1733-
u0s_in_obs = filter(in(allobs), keys(u0map))
1734-
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
1735-
end
1736-
end
1737-
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
1738-
observedmap = isempty(obs) ? Dict() : todict(obs)
1739-
defs = mergedefaults(defs, observedmap, u0map, dvs)
1740-
for (k, v) in defs
1741-
if Symbolics.isarraysymbolic(k)
1742-
ks = scalarize(k)
1743-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
1744-
for (kk, vv) in zip(ks, v)
1745-
if !haskey(defs, kk)
1746-
defs[kk] = vv
1747-
end
1748-
end
1749-
end
1750-
end
1716+
isempty(missing_dvs) || throw(MissingVariablesError(collect(missing_dvs)))
17511717

1752-
if symbolic_u0
1753-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
1754-
else
1755-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat, use_union)
1756-
end
1757-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
1758-
p = p === nothing ? SciMLBase.NullParameters() : p
1759-
t0 !== nothing && delete!(defs, get_iv(sys))
1760-
u0, p, defs
1718+
return varmap_to_vars(op, dvs; kwargs...)
17611719
end
17621720

1763-
function get_u0(
1764-
sys, u0map, parammap = nothing; symbolic_u0 = false,
1765-
toterm = default_toterm, t0 = nothing, use_union = true)
1721+
"""
1722+
$(TYPEDSIGNATURES)
1723+
1724+
Return the `u0` vector for the given system `sys` and variable-value mapping `varmap`. All
1725+
keyword arguments are forwarded to [`MTKParameters`](@ref) for split systems and
1726+
[`varmap_to_vars`](@ref) for non-split systems.
1727+
"""
1728+
function get_p(sys::AbstractSystem, varmap; split = is_split(sys), kwargs...)
17661729
dvs = unknowns(sys)
1767-
ps = parameters(sys)
1768-
defs = defaults(sys)
1769-
if t0 !== nothing
1770-
defs[get_iv(sys)] = t0
1771-
end
1772-
if parammap !== nothing
1773-
defs = mergedefaults(defs, parammap, ps)
1774-
end
1775-
1776-
# Convert observed equations "lhs ~ rhs" into defaults.
1777-
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
1778-
# if "lhs" is known by other means (parameter, another default, ...)
1779-
# TODO: Is there a better way to determine which equations to flip?
1780-
obs = map(x -> x.lhs => x.rhs, observed(sys))
1781-
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
1782-
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
1783-
obsmap = isempty(obs) ? Dict() : todict(obs)
1784-
1785-
defs = mergedefaults(defs, obsmap, u0map, dvs)
1786-
if symbolic_u0
1787-
u0 = varmap_to_vars(
1788-
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
1730+
ps = parameters(sys; initial_parameters = true)
1731+
op = to_varmap(varmap, dvs)
1732+
add_observed!(sys, op)
1733+
add_parameter_dependencies!(sys, op)
1734+
_, missing_ps = build_operating_point!(
1735+
sys, op, Dict(), Dict(), defaults(sys), dvs, ps)
1736+
1737+
isempty(missing_ps) || throw(MissingParametersError(collect(missing_ps)))
1738+
1739+
if split
1740+
MTKParameters(sys, op; kwargs...)
17891741
else
1790-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
1742+
varmap_to_vars(op, ps; kwargs...)
17911743
end
1792-
t0 !== nothing && delete!(defs, get_iv(sys))
1793-
return u0, defs
17941744
end

0 commit comments

Comments
 (0)