Skip to content

Commit 3c43431

Browse files
refactor: move get_u0_p and get_u0 to problem_utils.jl
1 parent 82d815c commit 3c43431

File tree

2 files changed

+95
-91
lines changed

2 files changed

+95
-91
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -702,97 +702,6 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
702702
!linenumbers ? Base.remove_linenums!(ex) : ex
703703
end
704704

705-
"""
706-
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
707-
708-
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
709-
"""
710-
function get_u0_p(sys,
711-
u0map,
712-
parammap = nothing;
713-
t0 = nothing,
714-
use_union = true,
715-
tofloat = true,
716-
symbolic_u0 = false)
717-
dvs = unknowns(sys)
718-
ps = parameters(sys)
719-
720-
defs = defaults(sys)
721-
if t0 !== nothing
722-
defs[get_iv(sys)] = t0
723-
end
724-
if parammap !== nothing
725-
defs = mergedefaults(defs, parammap, ps)
726-
end
727-
if u0map isa Vector && eltype(u0map) <: Pair
728-
u0map = Dict(u0map)
729-
end
730-
if u0map isa Dict
731-
allobs = Set(getproperty.(observed(sys), :lhs))
732-
if any(in(allobs), keys(u0map))
733-
u0s_in_obs = filter(in(allobs), keys(u0map))
734-
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
735-
end
736-
end
737-
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
738-
observedmap = isempty(obs) ? Dict() : todict(obs)
739-
defs = mergedefaults(defs, observedmap, u0map, dvs)
740-
for (k, v) in defs
741-
if Symbolics.isarraysymbolic(k)
742-
ks = scalarize(k)
743-
length(ks) == length(v) || error("$k has default value $v with unmatched size")
744-
for (kk, vv) in zip(ks, v)
745-
if !haskey(defs, kk)
746-
defs[kk] = vv
747-
end
748-
end
749-
end
750-
end
751-
752-
if symbolic_u0
753-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
754-
else
755-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
756-
end
757-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
758-
p = p === nothing ? SciMLBase.NullParameters() : p
759-
t0 !== nothing && delete!(defs, get_iv(sys))
760-
u0, p, defs
761-
end
762-
763-
function get_u0(
764-
sys, u0map, parammap = nothing; symbolic_u0 = false,
765-
toterm = default_toterm, t0 = nothing, use_union = true)
766-
dvs = unknowns(sys)
767-
ps = parameters(sys)
768-
defs = defaults(sys)
769-
if t0 !== nothing
770-
defs[get_iv(sys)] = t0
771-
end
772-
if parammap !== nothing
773-
defs = mergedefaults(defs, parammap, ps)
774-
end
775-
776-
# Convert observed equations "lhs ~ rhs" into defaults.
777-
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
778-
# if "lhs" is known by other means (parameter, another default, ...)
779-
# TODO: Is there a better way to determine which equations to flip?
780-
obs = map(x -> x.lhs => x.rhs, observed(sys))
781-
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
782-
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
783-
obsmap = isempty(obs) ? Dict() : todict(obs)
784-
785-
defs = mergedefaults(defs, obsmap, u0map, dvs)
786-
if symbolic_u0
787-
u0 = varmap_to_vars(
788-
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
789-
else
790-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
791-
end
792-
t0 !== nothing && delete!(defs, get_iv(sys))
793-
return u0, defs
794-
end
795-
796705
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
797706
ODEFunctionExpr{true}(sys, args...; kwargs...)
798707
end

src/systems/problem_utils.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,98 @@ function process_SciMLProblem(
501501
kwargs...)
502502
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
503503
end
504+
505+
##############
506+
# Legacy functions for backward compatibility
507+
##############
508+
509+
"""
510+
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
511+
512+
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
513+
"""
514+
function get_u0_p(sys,
515+
u0map,
516+
parammap = nothing;
517+
t0 = nothing,
518+
use_union = true,
519+
tofloat = true,
520+
symbolic_u0 = false)
521+
dvs = unknowns(sys)
522+
ps = parameters(sys)
523+
524+
defs = defaults(sys)
525+
if t0 !== nothing
526+
defs[get_iv(sys)] = t0
527+
end
528+
if parammap !== nothing
529+
defs = mergedefaults(defs, parammap, ps)
530+
end
531+
if u0map isa Vector && eltype(u0map) <: Pair
532+
u0map = Dict(u0map)
533+
end
534+
if u0map isa Dict
535+
allobs = Set(getproperty.(observed(sys), :lhs))
536+
if any(in(allobs), keys(u0map))
537+
u0s_in_obs = filter(in(allobs), keys(u0map))
538+
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
539+
end
540+
end
541+
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
542+
observedmap = isempty(obs) ? Dict() : todict(obs)
543+
defs = mergedefaults(defs, observedmap, u0map, dvs)
544+
for (k, v) in defs
545+
if Symbolics.isarraysymbolic(k)
546+
ks = scalarize(k)
547+
length(ks) == length(v) || error("$k has default value $v with unmatched size")
548+
for (kk, vv) in zip(ks, v)
549+
if !haskey(defs, kk)
550+
defs[kk] = vv
551+
end
552+
end
553+
end
554+
end
555+
556+
if symbolic_u0
557+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
558+
else
559+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
560+
end
561+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
562+
p = p === nothing ? SciMLBase.NullParameters() : p
563+
t0 !== nothing && delete!(defs, get_iv(sys))
564+
u0, p, defs
565+
end
566+
567+
function get_u0(
568+
sys, u0map, parammap = nothing; symbolic_u0 = false,
569+
toterm = default_toterm, t0 = nothing, use_union = true)
570+
dvs = unknowns(sys)
571+
ps = parameters(sys)
572+
defs = defaults(sys)
573+
if t0 !== nothing
574+
defs[get_iv(sys)] = t0
575+
end
576+
if parammap !== nothing
577+
defs = mergedefaults(defs, parammap, ps)
578+
end
579+
580+
# Convert observed equations "lhs ~ rhs" into defaults.
581+
# Use the order "lhs => rhs" by default, but flip it to "rhs => lhs"
582+
# if "lhs" is known by other means (parameter, another default, ...)
583+
# TODO: Is there a better way to determine which equations to flip?
584+
obs = map(x -> x.lhs => x.rhs, observed(sys))
585+
obs = map(x -> x[1] in keys(defs) ? reverse(x) : x, obs)
586+
obs = filter!(x -> !(x[1] isa Number), obs) # exclude e.g. "0 => x^2 + y^2 - 25"
587+
obsmap = isempty(obs) ? Dict() : todict(obs)
588+
589+
defs = mergedefaults(defs, obsmap, u0map, dvs)
590+
if symbolic_u0
591+
u0 = varmap_to_vars(
592+
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
593+
else
594+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
595+
end
596+
t0 !== nothing && delete!(defs, get_iv(sys))
597+
return u0, defs
598+
end

0 commit comments

Comments
 (0)