Skip to content

Commit 50504ab

Browse files
committed
adding tests
1 parent 2a25200 commit 50504ab

File tree

3 files changed

+261
-127
lines changed

3 files changed

+261
-127
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ ArrayInterface = "6, 7"
8282
BifurcationKit = "0.4"
8383
BlockArrays = "1.1"
8484
BoundaryValueDiffEq = "5.12.0"
85+
BoundaryValueDiffEqAscher = "1.1.0"
8586
ChainRulesCore = "1"
8687
Combinatorics = "1"
8788
CommonSolve = "0.2.4"
@@ -140,8 +141,8 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
140141
SparseArrays = "1"
141142
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
142143
StaticArrays = "0.10, 0.11, 0.12, 1.0"
143-
StochasticDiffEq = "6.72.1"
144144
StochasticDelayDiffEq = "1.8.1"
145+
StochasticDiffEq = "6.72.1"
145146
SymbolicIndexingInterface = "0.3.36"
146147
SymbolicUtils = "3.7"
147148
Symbolics = "6.19"
@@ -154,6 +155,7 @@ julia = "1.9"
154155
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
155156
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
156157
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
158+
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
157159
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
158160
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
159161
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
@@ -185,4 +187,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
185187
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
186188

187189
[targets]
188-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
190+
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 110 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -853,15 +853,47 @@ get_callback(prob::ODEProblem) = prob.kwargs[:callback]
853853
```julia
854854
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
855855
parammap = DiffEqBase.NullParameters();
856+
constraints = nothing, guesses = nothing,
856857
version = nothing, tgrad = false,
857858
jac = true, sparse = true,
858859
simplify = false,
859860
kwargs...) where {iip}
860861
```
861862
862-
Create a `BVProblem` from the [`ODESystem`](@ref). The arguments `dvs` and
863+
Create a boundary value problem from the [`ODESystem`](@ref). The arguments `dvs` and
863864
`ps` are used to set the order of the dependent variable and parameter vectors,
864-
respectively. `u0map` should be used to specify the initial condition.
865+
respectively. `u0map` is used to specify fixed initial values for the states.
866+
867+
Every variable must have either an initial guess supplied using `guesses` or
868+
a fixed initial value specified using `u0map`.
869+
870+
`constraints` are used to specify boundary conditions to the ODESystem in the
871+
form of equations. These values should specify values that state variables should
872+
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
873+
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
874+
specified as one of the equations used to build the `ODESystem`. Below is an example.
875+
876+
```julia
877+
@parameters g
878+
@variables x(..) y(t) [state_priority = 10] λ(t)
879+
eqs = [D(D(x(t))) ~ λ * x(t)
880+
D(D(y)) ~ λ * y - g
881+
x(t)^2 + y^2 ~ 1]
882+
@mtkbuild pend = ODESystem(eqs, t)
883+
884+
tspan = (0.0, 1.5)
885+
u0map = [x(t) => 0.6, y => 0.8]
886+
parammap = [g => 1]
887+
guesses = [λ => 1]
888+
constraints = [x(0.5) ~ 1]
889+
890+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
891+
```
892+
893+
If no `constraints` are specified, the problem will be treated as an initial value problem.
894+
895+
If the `ODESystem` has algebraic equations like `x(t)^2 + y(t)^2`, the resulting
896+
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
865897
"""
866898
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
867899
BVProblem{true}(sys, args...; kwargs...)
@@ -873,7 +905,7 @@ function SciMLBase.BVProblem(sys::AbstractODESystem,
873905
kwargs...)
874906
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
875907
end
876-
o
908+
877909
function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
878910
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
879911
end
@@ -885,45 +917,71 @@ end
885917
function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
886918
tspan = get_tspan(sys),
887919
parammap = DiffEqBase.NullParameters();
920+
constraints = nothing, guesses = nothing,
888921
version = nothing, tgrad = false,
889922
callback = nothing,
890923
check_length = true,
891924
warn_initialize_determined = true,
892925
eval_expression = false,
893926
eval_module = @__MODULE__,
894927
kwargs...) where {iip, specialize}
928+
895929
if !iscomplete(sys)
896930
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
897931
end
932+
!isnothing(callbacks) && error("BVP solvers do not support callbacks.")
898933

899-
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
900-
t = tspan !== nothing ? tspan[1] : tspan,
901-
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
934+
iv = get_iv(sys)
935+
constraintsts = nothing
936+
constraintps = nothing
937+
sts = unknowns(sys)
938+
ps = parameters(sys)
902939

903-
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
904-
kwargs = filter_kwargs(kwargs)
940+
if !isnothing(constraints)
941+
constraints isa Equation ||
942+
constraints isa Vector{Equation} ||
943+
error("Constraints must be specified as an equation or a vector of equations.")
905944

906-
kwargs1 = (;)
907-
if cbs !== nothing
908-
kwargs1 = merge(kwargs1, (callback = cbs,))
945+
(length(constraints) + length(u0map) > length(sts)) &&
946+
error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
947+
948+
constraintsts = OrderedSet()
949+
constraintps = OrderedSet()
950+
951+
for eq in constraints
952+
collect_vars!(constraintsts, constraintps, eq, iv)
953+
validate_constraint_syms(eq, constraintsts, constraintps, Set(sts), Set(ps), iv)
954+
empty!(constraintsts)
955+
empty!(constraintps)
956+
end
909957
end
910958

911-
# Handle algebraic equations
912-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
913-
pidxmap = Dict([v => i for (i, v) in enumerate(parameters(sys))])
914-
ns = length(stmap)
915-
ne = length(get_alg_eqs(sys))
959+
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
960+
t = tspan !== nothing ? tspan[1] : tspan, guesses,
961+
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
962+
963+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
964+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
965+
966+
# Indices of states that have initial constraints.
967+
u0i = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for k in keys(u0map)]
968+
ni = length(u0i)
916969

917-
# Define the boundary conditions.
918-
bc = if has_alg_eqs(sys)
970+
bc = if !isnothing(constraints)
971+
ne = length(constraints)
919972
if iip
920973
(residual,u,p,t) -> begin
921-
residual[1:ns] .= u[1] .- u0
922-
residual[ns+1:ns+ne] .= sub_u_p_into_symeq.(get_alg_eqs(sys))
974+
residual[1:ni] .= u[1][u0i] .- u0[u0i]
975+
residual[ni+1:ni+ne] .= map(constraints) do cons
976+
sub_u_p_into_symeq(cons.rhs - cons.lhs, u, p, stidxmap, pidxmap, iv, tspan)
977+
end
923978
end
924979
else
925980
(u,p,t) -> begin
926-
resid = vcat(u[1] - u0, sub_u_p_into_symeq.(get_alg_eqs(sys)))
981+
consresid = map(constraints) do cons
982+
sub_u_p_into_symeq(cons.rhs-cons.lhs, u, p, stidxmap, pidxmap, iv, tspan)
983+
end
984+
resid = vcat(u[1][u0i] - u0[u0i], consresid)
927985
end
928986
end
929987
else
@@ -941,32 +999,54 @@ end
941999

9421000
get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
9431001

944-
# Helper to create the dictionary that will substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function.
1002+
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
1003+
function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv)
1004+
ModelingToolkit.check_variables(constraintsts)
1005+
ModelingToolkit.check_parameters(constraintps)
1006+
1007+
for var in constraintsts
1008+
if arguments(var) == iv
1009+
var sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
1010+
error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
1011+
else
1012+
operation(var)(iv) sts || error("Constraint equation $eq contains a variable $(operation(var)) that is not a variable of the ODESystem.")
1013+
end
1014+
end
1015+
1016+
for var in constraintps
1017+
if !iscall(var)
1018+
var ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
1019+
else
1020+
operation(var) ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
1021+
end
1022+
end
1023+
end
1024+
1025+
# Helper to substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function.
9451026
# Take a system with variables x,y, parameters g
9461027
#
947-
# 1 + x + y → 1 + u[1][1] + u[1][2]
1028+
# 1 + x(0) + y(0) → 1 + u[1][1] + u[1][2]
9481029
# x(0.5) → u(0.5)[1]
9491030
# x(0.5)*g(0.5) → u(0.5)[1]*p[1]
950-
951-
function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap)
952-
iv = ModelingToolkit.get_iv(sys)
1031+
function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap, iv, tspan)
9531032
eq = Symbolics.unwrap(eq)
9541033

955-
stmap = Dict([st => u[1][i] for st => i in stidxmap])
956-
pmap = Dict([pa => p[i] for pa => i in pidxmap])
1034+
stmap = Dict([st => u[1][i] for (st, i) in stidxmap])
1035+
pmap = Dict([pa => p[i] for (pa, i) in pidxmap])
9571036
eq = Symbolics.substitute(eq, merge(stmap, pmap))
9581037

9591038
csyms = []
9601039
# Find most nested calls, substitute those first.
9611040
while !isempty(find_callable_syms!(csyms, eq))
9621041
for sym in csyms
963-
t = arguments(sym)[1]
9641042
x = operation(sym)
1043+
t = arguments(sym)[1]
1044+
prog = (tspan[2] - tspan[1])/(t - tspan[1]) # 1 / the % of the timespan elapsed
9651045

9661046
if isparameter(x)
9671047
eq = Symbolics.substitute(eq, Dict(x(t) => p[pidxmap[x(iv)]]))
9681048
elseif isvariable(x)
969-
eq = Symbolics.substitute(eq, Dict(x(t) => u(val)[stidxmap[x(iv)]]))
1049+
eq = Symbolics.substitute(eq, Dict(x(t) => u[Int(end ÷ prog)][stidxmap[x(iv)]]))
9701050
end
9711051
end
9721052
empty!(csyms)

0 commit comments

Comments
 (0)