Skip to content

Commit 757c92a

Browse files
refactor: move get_u0_p and get_u0 to problem_utils.jl
1 parent 841dca8 commit 757c92a

File tree

2 files changed

+95
-91
lines changed

2 files changed

+95
-91
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -702,97 +702,6 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
702702
!linenumbers ? Base.remove_linenums!(ex) : ex
703703
end
704704

705-
"""
706-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
707-
708-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
709-
"""
710-
function get_u0_p(sys,
711-
u0map,
712-
parammap = nothing;
713-
t0 = nothing,
714-
use_union = true,
715-
tofloat = true,
716-
symbolic_u0 = false)
717-
dvs = unknowns(sys)
718-
ps = parameters(sys)
719-
720-
defs = defaults(sys)
721-
if t0 !== nothing
722-
defs[get_iv(sys)] = t0
723-
end
724-
if parammap !== nothing
725-
defs = mergedefaults(defs, parammap, ps)
726-
end
727-
if u0map isa Vector && eltype(u0map) <: Pair
728-
u0map = Dict(u0map)
729-
end
730-
if u0map isa Dict
731-
allobs = Set(getproperty.(observed(sys), :lhs))
732-
if any(in(allobs), keys(u0map))
733-
u0s_in_obs = filter(in(allobs), keys(u0map))
734-
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
735-
end
736-
end
737-
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
738-
observedmap = isempty(obs) ? Dict() : todict(obs)
739-
defs = mergedefaults(defs, observedmap, u0map, dvs)
740-
for (k, v) in defs
741-
if Symbolics.isarraysymbolic(k)
742-
ks = scalarize(k)
743-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
744-
for (kk, vv) in zip(ks, v)
745-
if !haskey(defs, kk)
746-
defs[kk] = vv
747-
end
748-
end
749-
end
750-
end
751-
752-
if symbolic_u0
753-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
754-
else
755-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
756-
end
757-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
758-
p = p === nothing ? SciMLBase.NullParameters() : p
759-
t0 !== nothing && delete!(defs, get_iv(sys))
760-
u0, p, defs
761-
end
762-
763-
function get_u0(
764-
sys, u0map, parammap = nothing; symbolic_u0 = false,
765-
toterm = default_toterm, t0 = nothing, use_union = true)
766-
dvs = unknowns(sys)
767-
ps = parameters(sys)
768-
defs = defaults(sys)
769-
if t0 !== nothing
770-
defs[get_iv(sys)] = t0
771-
end
772-
if parammap !== nothing
773-
defs = mergedefaults(defs, parammap, ps)
774-
end
775-
776-
# Convert observed equations "lhs ~ rhs" into defaults.
777-
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
778-
# if "lhs" is known by other means (parameter, another default, ...)
779-
# TODO: Is there a better way to determine which equations to flip?
780-
obs = map(x -> x.lhs => x.rhs, observed(sys))
781-
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
782-
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
783-
obsmap = isempty(obs) ? Dict() : todict(obs)
784-
785-
defs = mergedefaults(defs, obsmap, u0map, dvs)
786-
if symbolic_u0
787-
u0 = varmap_to_vars(
788-
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
789-
else
790-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
791-
end
792-
t0 !== nothing && delete!(defs, get_iv(sys))
793-
return u0, defs
794-
end
795-
796705
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
797706
ODEFunctionExpr{true}(sys, args...; kwargs...)
798707
end

src/systems/problem_utils.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,3 +509,98 @@ function process_SciMLProblem(
509509
kwargs...)
510510
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
511511
end
512+
513+
##############
514+
# Legacy functions for backward compatibility
515+
##############
516+
517+
"""
518+
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
519+
520+
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
521+
"""
522+
function get_u0_p(sys,
523+
u0map,
524+
parammap = nothing;
525+
t0 = nothing,
526+
use_union = true,
527+
tofloat = true,
528+
symbolic_u0 = false)
529+
dvs = unknowns(sys)
530+
ps = parameters(sys)
531+
532+
defs = defaults(sys)
533+
if t0 !== nothing
534+
defs[get_iv(sys)] = t0
535+
end
536+
if parammap !== nothing
537+
defs = mergedefaults(defs, parammap, ps)
538+
end
539+
if u0map isa Vector && eltype(u0map) <: Pair
540+
u0map = Dict(u0map)
541+
end
542+
if u0map isa Dict
543+
allobs = Set(getproperty.(observed(sys), :lhs))
544+
if any(in(allobs), keys(u0map))
545+
u0s_in_obs = filter(in(allobs), keys(u0map))
546+
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
547+
end
548+
end
549+
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
550+
observedmap = isempty(obs) ? Dict() : todict(obs)
551+
defs = mergedefaults(defs, observedmap, u0map, dvs)
552+
for (k, v) in defs
553+
if Symbolics.isarraysymbolic(k)
554+
ks = scalarize(k)
555+
length(ks) == length(v) || error("$k has default value $v with unmatched size")
556+
for (kk, vv) in zip(ks, v)
557+
if !haskey(defs, kk)
558+
defs[kk] = vv
559+
end
560+
end
561+
end
562+
end
563+
564+
if symbolic_u0
565+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
566+
else
567+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
568+
end
569+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
570+
p = p === nothing ? SciMLBase.NullParameters() : p
571+
t0 !== nothing && delete!(defs, get_iv(sys))
572+
u0, p, defs
573+
end
574+
575+
function get_u0(
576+
sys, u0map, parammap = nothing; symbolic_u0 = false,
577+
toterm = default_toterm, t0 = nothing, use_union = true)
578+
dvs = unknowns(sys)
579+
ps = parameters(sys)
580+
defs = defaults(sys)
581+
if t0 !== nothing
582+
defs[get_iv(sys)] = t0
583+
end
584+
if parammap !== nothing
585+
defs = mergedefaults(defs, parammap, ps)
586+
end
587+
588+
# Convert observed equations "lhs ~ rhs" into defaults.
589+
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
590+
# if "lhs" is known by other means (parameter, another default, ...)
591+
# TODO: Is there a better way to determine which equations to flip?
592+
obs = map(x -> x.lhs => x.rhs, observed(sys))
593+
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
594+
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
595+
obsmap = isempty(obs) ? Dict() : todict(obs)
596+
597+
defs = mergedefaults(defs, obsmap, u0map, dvs)
598+
if symbolic_u0
599+
u0 = varmap_to_vars(
600+
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
601+
else
602+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
603+
end
604+
t0 !== nothing && delete!(defs, get_iv(sys))
605+
return u0, defs
606+
end

0 commit comments

Comments
 (0)