Skip to content

Commit db5eb66

Browse files
committed
up
1 parent b83e003 commit db5eb66

File tree

2 files changed

+76
-101
lines changed

2 files changed

+76
-101
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
939939
ps = parameters(sys)
940940

941941
# Constraint validation
942-
f_cons = if !isnothing(constraints)
942+
if !isnothing(constraints)
943943
constraints isa Equation ||
944944
constraints isa Vector{Equation} ||
945945
error("Constraints must be specified as an equation or a vector of equations.")
@@ -958,32 +958,6 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
958958
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
959959
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
960960

961-
# bc = if !isnothing(constraints) && iip
962-
# (residual,u,p,t) -> begin
963-
# println(u(0.5))
964-
# residual[1:ni] .= u[1][u0i] .- u0[u0i]
965-
# for (i, cons) in enumerate(constraints)
966-
# residual[ni+i] = eval_symbolic_residual(cons, u, p, stidxmap, pidxmap, iv, tspan)
967-
# end
968-
# end
969-
970-
# elseif !isnothing(constraints) && !iip
971-
# (u,p,t) -> begin
972-
# consresid = map(constraints) do cons
973-
# eval_symbolic_residual(cons, u, p, stidxmap, pidxmap, iv, tspan)
974-
# end
975-
# resid = vcat(u[1][u0i] - u0[u0i], consresid)
976-
# end
977-
978-
# elseif iip
979-
# (residual,u,p,t) -> begin
980-
# println(u(0.5))
981-
# residual .= u[1] .- u0
982-
# end
983-
984-
# else
985-
# (u,p,t) -> (u[1] - u0)
986-
# end
987961
bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)
988962

989963
return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
@@ -1075,11 +1049,10 @@ function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, ii
10751049
end
10761050

10771051
exprs = vcat(init_cond_exprs, exprs)
1052+
@show exprs
10781053
bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
10791054
if iip
1080-
return (resid, u, p, t) -> begin
1081-
bcs[2](resid, u, p)
1082-
end
1055+
return (resid, u, p, t) -> bcs[2](resid, u, p)
10831056
else
10841057
return (u, p, t) -> bcs[1](u, p)
10851058
end

test/bvproblem.jl

Lines changed: 73 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
44
using ModelingToolkit
5+
using SciMLBase
56
using ModelingToolkit: t_nounits as t, D_nounits as D
7+
import ModelingToolkit: process_constraints
68

79
### Test Collocation solvers on simple problems
810
solvers = [MIRK4, RadauIIa5, LobattoIIIa3]
@@ -81,46 +83,9 @@ end
8183
### ODESystem with constraint equations, DAEs with constraints ###
8284
##################################################################
8385

84-
# Cartesian pendulum from the docs.
85-
# DAE IVP solved using BoundaryValueDiffEq solvers.
86-
# let
87-
# @parameters g
88-
# @variables x(t) y(t) [state_priority = 10] λ(t)
89-
# eqs = [D(D(x)) ~ λ * x
90-
# D(D(y)) ~ λ * y - g
91-
# x^2 + y^2 ~ 1]
92-
# @mtkbuild pend = ODESystem(eqs, t)
93-
#
94-
# tspan = (0.0, 1.5)
95-
# u0map = [x => 1, y => 0]
96-
# pmap = [g => 1]
97-
# guess = [λ => 1]
98-
#
99-
# prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
100-
# osol = solve(prob, Rodas5P())
101-
#
102-
# zeta = [0., 0., 0., 0., 0.]
103-
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
104-
#
105-
# for solver in solvers
106-
# sol = solve(bvp, solver(zeta), dt = 0.001)
107-
# @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
108-
# conditions = getfield.(equations(pend)[3:end], :rhs)
109-
# @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
110-
# end
111-
#
112-
# bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
113-
# for solver in solvers
114-
# sol = solve(bvp, solver(zeta), dt = 0.01)
115-
# @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
116-
# conditions = getfield.(equations(pend)[3:end], :rhs)
117-
# @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0
118-
# end
119-
# end
120-
12186
# Test generation of boundary condition function.
12287
let
123-
@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
88+
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
12489
@variables x(..) y(t)
12590
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
12691
D(y) ~ -γ * y + δ * x(t) * y]
@@ -130,11 +95,11 @@ let
13095

13196
function lotkavolterra!(du, u, p, t)
13297
du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
133-
du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
98+
du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
13499
end
135100

136101
function lotkavolterra(u, p, t)
137-
[p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]
102+
[p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
138103
end
139104
# Compare the built bc function to the actual constructed one.
140105
function bc!(resid, u, p, t)
@@ -146,23 +111,22 @@ let
146111
[u[1][1] - 1., u[1][2] - 2.]
147112
end
148113

149-
constraints = nothing
150-
u0 = [1., 2.]; p = [7.5, 4., 8., 5.]
151-
genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [1, 2], tspan, true)
152-
genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [1, 2], tspan, false)
114+
u0 = [1., 2.]; p = [1.5, 1., 3., 1.]
115+
genbc_iip = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, true)
116+
genbc_oop = ModelingToolkit.process_constraints(lksys, nothing, u0, [1, 2], tspan, false)
153117

154-
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1,2], tspan, p)
155-
bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1,2], tspan, p)
118+
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
119+
bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
156120

157-
sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
158-
sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
121+
sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
122+
sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
159123
@test sol1 sol2
160124

161125
bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
162126
bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
163127

164-
sol1 = solve(bvpo1, MIRK4(), dt = 0.01)
165-
sol2 = solve(bvpo2, MIRK4(), dt = 0.01)
128+
sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
129+
sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
166130
@test sol1 sol2
167131

168132
# Test with a constraint.
@@ -180,28 +144,28 @@ let
180144
genbc_iip = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, true)
181145
genbc_oop = ModelingToolkit.process_constraints(lksys, constraints, u0, [2], tspan, false)
182146

183-
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1,2], tspan, p)
184-
bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1,2], tspan, p)
185-
186-
sol1 = solve(bvpi1, MIRK4(), dt = 0.01)
187-
sol2 = solve(bvpi2, MIRK4(), dt = 0.01)
147+
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, [1.,2.], tspan, p)
148+
bvpi2 = SciMLBase.BVProblem(lotkavolterra!, genbc_iip, [1.,2.], tspan, p)
149+
bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan, parammap; guesses, constraints)
150+
151+
sol1 = solve(bvpi1, MIRK4(), dt = 0.05)
152+
sol2 = solve(bvpi2, MIRK4(), dt = 0.05)
188153
@test sol1 sol2 # don't get true equality here, not sure why
189154

190155
bvpo1 = BVProblem(lotkavolterra, bc, [1,2], tspan, p)
191156
bvpo2 = BVProblem(lotkavolterra, genbc_oop, [1,2], tspan, p)
192157

193-
sol1 = solve(bvpo1, MIRK4(), dt = 0.01)
194-
sol2 = solve(bvpo2, MIRK4(), dt = 0.01)
158+
sol1 = solve(bvpo1, MIRK4(), dt = 0.05)
159+
sol2 = solve(bvpo2, MIRK4(), dt = 0.05)
195160
@test sol1 sol2
196161
end
197162

198-
function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.01)
163+
function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05)
199164
for solver in solvers
200165
sol = solve(prob, solver(); dt)
201-
@test successful_retcode(sol.retcode)
166+
@test SciMLBase.successful_retcode(sol.retcode)
202167
p = prob.p; t = sol.t; bc = prob.f.bc
203168
ns = length(prob.u0)
204-
205169
if isinplace(bvp.f)
206170
resid = zeros(ns)
207171
bc!(resid, sol, p, t)
@@ -226,45 +190,83 @@ end
226190

227191
# Simple ODESystem with BVP constraints.
228192
let
229-
@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
193+
t = ModelingToolkit.t_nounits; D = ModelingToolkit.D_nounits
194+
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
230195
@variables x(..) y(t)
231196

232197
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y,
233198
D(y) ~ -γ * y + δ * x(t) * y]
234199

235-
u0map = [y => 2.0]
200+
u0map = []
236201
parammap ==> 7.5, β => 4, γ => 8.0, δ => 5.0]
237202
tspan = (0.0, 10.0)
238-
guesses = [x(t) => 1.0]
203+
guesses = [x(t) => 1.0, y => 2.]
239204

240205
@mtkbuild lotkavolterra = ODESystem(eqs, t)
241206

242-
constraints = [x(6.) ~ 3]
207+
constraints = [x(6.) ~ 1.5]
243208
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
244209
test_solvers(solvers, bvp, u0map, constraints)
245210

246211
# Testing that more complicated constraints give correct solutions.
247-
constraints = [y(2.) + x(8.) ~ 12]
248-
bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
212+
constraints = [y(2.) + x(8.) ~ 2.]
213+
bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
249214
test_solvers(solvers, bvp, u0map, constraints)
250215

251-
constraints =* β - x(6.) ~ 24]
252-
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
216+
constraints =* β - x(6.) ~ 0.5]
217+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
253218
test_solvers(solvers, bvp, u0map, constraints)
254219

255220
# Testing that errors are properly thrown when malformed constraints are given.
256221
@variables bad(..)
257222
constraints = [x(1.) + bad(3.) ~ 10]
258-
@test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
223+
@test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
259224

260225
constraints = [x(t) + y(t) ~ 3]
261-
@test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
226+
@test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
262227

263228
@parameters bad2
264229
constraints = [bad2 + x(0.) ~ 3]
265-
@test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, constraints)
230+
@test_throws Exception bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; guesses, constraints)
266231
end
267232

233+
# Cartesian pendulum from the docs.
234+
# DAE IVP solved using BoundaryValueDiffEq solvers.
235+
# let
236+
# @parameters g
237+
# @variables x(t) y(t) [state_priority = 10] λ(t)
238+
# eqs = [D(D(x)) ~ λ * x
239+
# D(D(y)) ~ λ * y - g
240+
# x^2 + y^2 ~ 1]
241+
# @mtkbuild pend = ODESystem(eqs, t)
242+
#
243+
# tspan = (0.0, 1.5)
244+
# u0map = [x => 1, y => 0]
245+
# pmap = [g => 1]
246+
# guess = [λ => 1]
247+
#
248+
# prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
249+
# osol = solve(prob, Rodas5P())
250+
#
251+
# zeta = [0., 0., 0., 0., 0.]
252+
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
253+
#
254+
# for solver in solvers
255+
# sol = solve(bvp, solver(zeta), dt = 0.001)
256+
# @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
257+
# conditions = getfield.(equations(pend)[3:end], :rhs)
258+
# @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
259+
# end
260+
#
261+
# bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
262+
# for solver in solvers
263+
# sol = solve(bvp, solver(zeta), dt = 0.01)
264+
# @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
265+
# conditions = getfield.(equations(pend)[3:end], :rhs)
266+
# @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0
267+
# end
268+
# end
269+
268270
# Adding a midpoint boundary constraint.
269271
# Solve using BVDAE solvers.
270272
# let

0 commit comments

Comments
 (0)