Skip to content

Commit 2a25200

Browse files
committed
extend BVProblem for constraint equations
1 parent d23d6f7 commit 2a25200

File tree

2 files changed

+160
-4
lines changed

2 files changed

+160
-4
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ function SciMLBase.BVProblem(sys::AbstractODESystem,
873873
kwargs...)
874874
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
875875
end
876-
876+
o
877877
function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
878878
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
879879
end
@@ -908,18 +908,84 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
908908
kwargs1 = merge(kwargs1, (callback = cbs,))
909909
end
910910

911+
# Handle algebraic equations
912+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
913+
pidxmap = Dict([v => i for (i, v) in enumerate(parameters(sys))])
914+
ns = length(stmap)
915+
ne = length(get_alg_eqs(sys))
916+
911917
# Define the boundary conditions.
912-
bc = if iip
913-
(residual, u, p, t) -> (residual .= u[1] .- u0)
918+
bc = if has_alg_eqs(sys)
919+
if iip
920+
(residual,u,p,t) -> begin
921+
residual[1:ns] .= u[1] .- u0
922+
residual[ns+1:ns+ne] .= sub_u_p_into_symeq.(get_alg_eqs(sys))
923+
end
924+
else
925+
(u,p,t) -> begin
926+
resid = vcat(u[1] - u0, sub_u_p_into_symeq.(get_alg_eqs(sys)))
927+
end
928+
end
914929
else
915-
(u, p, t) -> (u[1] - u0)
930+
if iip
931+
(residual,u,p,t) -> begin
932+
residual .= u[1] .- u0
933+
end
934+
else
935+
(u,p,t) -> (u[1] - u0)
936+
end
916937
end
917938

918939
return BVProblem{iip}(f, bc, u0, tspan, p; kwargs1..., kwargs...)
919940
end
920941

921942
get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
922943

944+
# Helper to create the dictionary that will substitute numeric values for u, p into the algebraic equations in the ODESystem. Used to construct the boundary condition function.
945+
# Take a system with variables x,y, parameters g
946+
#
947+
# 1 + x + y → 1 + u[1][1] + u[1][2]
948+
# x(0.5) → u(0.5)[1]
949+
# x(0.5)*g(0.5) → u(0.5)[1]*p[1]
950+
951+
function sub_u_p_into_symeq(eq, u, p, stidxmap, pidxmap)
952+
iv = ModelingToolkit.get_iv(sys)
953+
eq = Symbolics.unwrap(eq)
954+
955+
stmap = Dict([st => u[1][i] for st => i in stidxmap])
956+
pmap = Dict([pa => p[i] for pa => i in pidxmap])
957+
eq = Symbolics.substitute(eq, merge(stmap, pmap))
958+
959+
csyms = []
960+
# Find most nested calls, substitute those first.
961+
while !isempty(find_callable_syms!(csyms, eq))
962+
for sym in csyms
963+
t = arguments(sym)[1]
964+
x = operation(sym)
965+
966+
if isparameter(x)
967+
eq = Symbolics.substitute(eq, Dict(x(t) => p[pidxmap[x(iv)]]))
968+
elseif isvariable(x)
969+
eq = Symbolics.substitute(eq, Dict(x(t) => u(val)[stidxmap[x(iv)]]))
970+
end
971+
end
972+
empty!(csyms)
973+
end
974+
eq
975+
end
976+
977+
function find_callable_syms!(csyms, ex)
978+
ex = Symbolics.unwrap(ex)
979+
980+
if iscall(ex)
981+
operation(ex) isa Symbolic && (arguments(ex)[1] isa Symbolic) && push!(csyms, ex) # only add leaf nodes
982+
for arg in arguments(ex)
983+
find_callable_syms!(csyms, arg)
984+
end
985+
end
986+
csyms
987+
end
988+
923989
"""
924990
```julia
925991
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,

test/bvproblem.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using BoundaryValueDiffEq, OrdinaryDiffEq
22
using ModelingToolkit
33
using ModelingToolkit: t_nounits as t, D_nounits as D
44

5+
### Test Collocation solvers on simple problems
56
solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
67

78
@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
@@ -68,3 +69,92 @@ for solver in solvers
6869
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
6970
@test sol.u[1] ==/ 2, π / 2]
7071
end
72+
73+
###################################################
74+
### TESTING ODESystem with Constraint Equations ###
75+
###################################################
76+
77+
# Cartesian pendulum from the docs. Testing that initialization is satisfied.
78+
let
79+
@parameters g
80+
@variables x(t) y(t) [state_priority = 10] λ(t)
81+
eqs = [D(D(x)) ~ λ * x
82+
D(D(y)) ~ λ * y - g
83+
x^2 + y^2 ~ 1]
84+
@mtkbuild pend = ODESystem(eqs, t)
85+
86+
tspan = (0.0, 1.5)
87+
u0map = [x => 1, y => 0]
88+
parammap = [g => 1]
89+
guesses ==> 1]
90+
91+
prob = ODEProblem(pend, u0map, tspan, pmap; guesses)
92+
sol = solve(prob, Rodas5P())
93+
94+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses)
95+
96+
for solver in solvers
97+
sol = solve(bvp, solver(), dt = 0.01)
98+
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
99+
conditions = getfield.(equations(pend)[3:end], :rhs)
100+
@test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] 0
101+
end
102+
103+
bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
104+
for solver in solvers
105+
sol = solve(bvp, solver(), dt = 0.01)
106+
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
107+
conditions = getfield.(equations(pend)[3:end], :rhs)
108+
@test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] 0
109+
end
110+
end
111+
112+
# Adding a midpoint boundary condition.
113+
let
114+
@parameters g
115+
@variables x(..) y(t) [state_priority = 10] λ(t)
116+
eqs = [D(D(x(t))) ~ λ * x(t)
117+
D(D(y)) ~ λ * y - g
118+
x(t)^2 + y^2 ~ 1
119+
x(0.5) ~ 1]
120+
@mtkbuild pend = ODESystem(eqs, t)
121+
122+
tspan = (0.0, 1.5)
123+
u0map = [x(t) => 0.6, y => 0.8]
124+
parammap = [g => 1]
125+
guesses ==> 1]
126+
127+
prob = ODEProblem(pend, u0map, tspan, pmap; guesses, check_length = false)
128+
sol = solve(prob, Rodas5P())
129+
130+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guesses, check_length = false)
131+
132+
for solver in solvers
133+
sol = solve(bvp, solver(), dt = 0.01)
134+
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
135+
conditions = getfield.(equations(pend)[3:end], :rhs)
136+
@test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] 0
137+
@test sol.u[1] ==/ 2, π / 2]
138+
end
139+
140+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses)
141+
142+
for solver in solvers
143+
sol = solve(bvp, solver(), dt = 0.01)
144+
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
145+
conditions = getfield.(equations(pend)[3:end], :rhs)
146+
@test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] 0
147+
end
148+
149+
bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
150+
for solver in solvers
151+
sol = solve(bvp, solver(), dt = 0.01)
152+
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
153+
conditions = getfield.(equations(pend)[3:end], :rhs)
154+
@test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] 0
155+
end
156+
end
157+
158+
# Testing a more complicated case with multiple constraints.
159+
let
160+
end

0 commit comments

Comments
 (0)