Skip to content

Commit 13a242c

Browse files
committed
update to use updated codegen
1 parent 8ae2803 commit 13a242c

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -881,18 +881,23 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
881881
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
882882
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]
883883

884-
bc = generate_function_bc(sys, u0, u0_idxs, tspan, iip)
884+
fns = generate_function_bc(sys, u0, u0_idxs, tspan)
885+
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
886+
# bc(sol, p, t) = bc_oop(sol, p, t)
887+
bc(resid, u, p, t) = bc_iip(resid, u, p, t)
888+
885889
return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
886890
end
887891

888892
get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")
889893

890894
"""
891-
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
895+
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan)
892896
893897
Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
898+
Expression uses the constraints and the provided initial conditions.
894899
"""
895-
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
900+
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
896901
iv = get_iv(sys)
897902
sts = get_unknowns(sys)
898903
ps = get_ps(sys)
@@ -915,19 +920,6 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
915920

916921
cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
917922
end
918-
919-
for var in parameters(conssys)
920-
if iscall(var)
921-
x = operation(var)
922-
t = only(arguments(var))
923-
idx = pidxmap[x]
924-
925-
cons = map(c -> Symbolics.substitute(c, Dict(x(t) => p[idx])), cons)
926-
else
927-
idx = pidxmap[var]
928-
cons = map(c -> Symbolics.substitute(c, Dict(var => p[idx])), cons)
929-
end
930-
end
931923
end
932924

933925
init_conds = Any[]
@@ -937,12 +929,9 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan, iip)
937929
end
938930

939931
exprs = vcat(init_conds, cons)
940-
bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
941-
if iip
942-
return (resid, u, p, t) -> bcs[2](resid, u, p)
943-
else
944-
return (u, p, t) -> bcs[1](u, p)
945-
end
932+
_p = reorder_parameters(sys, ps)
933+
934+
build_function_wrapper(sys, exprs, sol, _p..., t; kwargs...)
946935
end
947936

948937
"""

test/bvproblem.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions
22

3-
using BoundaryValueDiffEq, OrdinaryDiffEq, BoundaryValueDiffEqAscher
3+
using OrdinaryDiffEqVerner
4+
using BoundaryValueDiffEqMIRK, BoundaryValueDiffEqAscher
45
using BenchmarkTools
56
using ModelingToolkit
67
using SciMLBase
@@ -207,22 +208,22 @@ let
207208

208209
u0map = []
209210
tspan = (0.0, 1.0)
210-
guesses = [x(t) => 4.0, y(t) => 2.]
211+
guess = [x(t) => 4.0, y(t) => 2.0]
211212
constr = [x(.6) ~ 3.5, x(.3) ~ 7.]
212213
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
213214

214-
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
215+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
215216
test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
216217

217218
# Testing that more complicated constraints give correct solutions.
218219
constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
219220
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
220-
bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses)
221+
bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses = guess)
221222
test_solvers(solvers, bvp, u0map, constr; dt = 0.05)
222223

223224
constr =* β - x(.6) ~ 0.0, y(.2) ~ 3.]
224225
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
225-
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses)
226+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
226227
test_solvers(solvers, bvp, u0map, constr)
227228
end
228229

0 commit comments

Comments
 (0)