Skip to content

Commit 78a0650

Browse files
refactor: move get_u0_p and get_u0 to problem_utils.jl
1 parent 3bf85c8 commit 78a0650

File tree

2 files changed

+95
-91
lines changed

2 files changed

+95
-91
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -702,97 +702,6 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
702702
!linenumbers ? Base.remove_linenums!(ex) : ex
703703
end
704704

705-
"""
706-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
707-
708-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
709-
"""
710-
function get_u0_p(sys,
711-
u0map,
712-
parammap = nothing;
713-
t0 = nothing,
714-
use_union = true,
715-
tofloat = true,
716-
symbolic_u0 = false)
717-
dvs = unknowns(sys)
718-
ps = parameters(sys)
719-
720-
defs = defaults(sys)
721-
if t0 !== nothing
722-
defs[get_iv(sys)] = t0
723-
end
724-
if parammap !== nothing
725-
defs = mergedefaults(defs, parammap, ps)
726-
end
727-
if u0map isa Vector && eltype(u0map) <: Pair
728-
u0map = Dict(u0map)
729-
end
730-
if u0map isa Dict
731-
allobs = Set(getproperty.(observed(sys), :lhs))
732-
if any(in(allobs), keys(u0map))
733-
u0s_in_obs = filter(in(allobs), keys(u0map))
734-
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
735-
end
736-
end
737-
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
738-
observedmap = isempty(obs) ? Dict() : todict(obs)
739-
defs = mergedefaults(defs, observedmap, u0map, dvs)
740-
for (k, v) in defs
741-
if Symbolics.isarraysymbolic(k)
742-
ks = scalarize(k)
743-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
744-
for (kk, vv) in zip(ks, v)
745-
if !haskey(defs, kk)
746-
defs[kk] = vv
747-
end
748-
end
749-
end
750-
end
751-
752-
if symbolic_u0
753-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
754-
else
755-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
756-
end
757-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
758-
p = p === nothing ? SciMLBase.NullParameters() : p
759-
t0 !== nothing && delete!(defs, get_iv(sys))
760-
u0, p, defs
761-
end
762-
763-
function get_u0(
764-
sys, u0map, parammap = nothing; symbolic_u0 = false,
765-
toterm = default_toterm, t0 = nothing, use_union = true)
766-
dvs = unknowns(sys)
767-
ps = parameters(sys)
768-
defs = defaults(sys)
769-
if t0 !== nothing
770-
defs[get_iv(sys)] = t0
771-
end
772-
if parammap !== nothing
773-
defs = mergedefaults(defs, parammap, ps)
774-
end
775-
776-
# Convert observed equations "lhs ~ rhs" into defaults.
777-
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
778-
# if "lhs" is known by other means (parameter, another default, ...)
779-
# TODO: Is there a better way to determine which equations to flip?
780-
obs = map(x -> x.lhs => x.rhs, observed(sys))
781-
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
782-
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
783-
obsmap = isempty(obs) ? Dict() : todict(obs)
784-
785-
defs = mergedefaults(defs, obsmap, u0map, dvs)
786-
if symbolic_u0
787-
u0 = varmap_to_vars(
788-
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
789-
else
790-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
791-
end
792-
t0 !== nothing && delete!(defs, get_iv(sys))
793-
return u0, defs
794-
end
795-
796705
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
797706
ODEFunctionExpr{true}(sys, args...; kwargs...)
798707
end

src/systems/problem_utils.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,98 @@ function process_SciMLProblem(
510510
kwargs...)
511511
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
512512
end
513+
514+
##############
515+
# Legacy functions for backward compatibility
516+
##############
517+
518+
"""
519+
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
520+
521+
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
522+
"""
523+
function get_u0_p(sys,
524+
u0map,
525+
parammap = nothing;
526+
t0 = nothing,
527+
use_union = true,
528+
tofloat = true,
529+
symbolic_u0 = false)
530+
dvs = unknowns(sys)
531+
ps = parameters(sys)
532+
533+
defs = defaults(sys)
534+
if t0 !== nothing
535+
defs[get_iv(sys)] = t0
536+
end
537+
if parammap !== nothing
538+
defs = mergedefaults(defs, parammap, ps)
539+
end
540+
if u0map isa Vector && eltype(u0map) <: Pair
541+
u0map = Dict(u0map)
542+
end
543+
if u0map isa Dict
544+
allobs = Set(getproperty.(observed(sys), :lhs))
545+
if any(in(allobs), keys(u0map))
546+
u0s_in_obs = filter(in(allobs), keys(u0map))
547+
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
548+
end
549+
end
550+
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
551+
observedmap = isempty(obs) ? Dict() : todict(obs)
552+
defs = mergedefaults(defs, observedmap, u0map, dvs)
553+
for (k, v) in defs
554+
if Symbolics.isarraysymbolic(k)
555+
ks = scalarize(k)
556+
length(ks) == length(v) || error("$k has default value $v with unmatched size")
557+
for (kk, vv) in zip(ks, v)
558+
if !haskey(defs, kk)
559+
defs[kk] = vv
560+
end
561+
end
562+
end
563+
end
564+
565+
if symbolic_u0
566+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
567+
else
568+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
569+
end
570+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
571+
p = p === nothing ? SciMLBase.NullParameters() : p
572+
t0 !== nothing && delete!(defs, get_iv(sys))
573+
u0, p, defs
574+
end
575+
576+
function get_u0(
577+
sys, u0map, parammap = nothing; symbolic_u0 = false,
578+
toterm = default_toterm, t0 = nothing, use_union = true)
579+
dvs = unknowns(sys)
580+
ps = parameters(sys)
581+
defs = defaults(sys)
582+
if t0 !== nothing
583+
defs[get_iv(sys)] = t0
584+
end
585+
if parammap !== nothing
586+
defs = mergedefaults(defs, parammap, ps)
587+
end
588+
589+
# Convert observed equations "lhs ~ rhs" into defaults.
590+
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
591+
# if "lhs" is known by other means (parameter, another default, ...)
592+
# TODO: Is there a better way to determine which equations to flip?
593+
obs = map(x -> x.lhs => x.rhs, observed(sys))
594+
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
595+
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
596+
obsmap = isempty(obs) ? Dict() : todict(obs)
597+
598+
defs = mergedefaults(defs, obsmap, u0map, dvs)
599+
if symbolic_u0
600+
u0 = varmap_to_vars(
601+
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
602+
else
603+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
604+
end
605+
t0 !== nothing && delete!(defs, get_iv(sys))
606+
return u0, defs
607+
end

0 commit comments

Comments
 (0)