Skip to content

Commit 01ce2b3

Browse files
Merge pull request #804 from mohamed82008/patch-1
Define inplace and out of place methods in ODEFunctionExpr
2 parents 965f6a9 + d51a0b9 commit 01ce2b3

File tree

2 files changed

+63
-30
lines changed

2 files changed

+63
-30
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -192,39 +192,53 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
192192
sparse = false, simplify=false,
193193
kwargs...) where {iip}
194194

195-
idx = iip ? 2 : 1
196-
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
195+
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)
196+
fsym = gensym(:f)
197+
_f = quote
198+
$fsym(u,p,t) = $f_oop(u,p,t)
199+
$fsym(du,u,p,t) = $f_iip(du,u,p,t)
200+
end
201+
202+
tgradsym = gensym(:tgrad)
197203
if tgrad
198-
_tgrad = generate_tgrad(sys, dvs, ps;
204+
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
199205
simplify=simplify,
200-
expression=Val{true}, kwargs...)[idx]
206+
expression=Val{true}, kwargs...)
207+
_tgrad = quote
208+
$tgradsym(u,p,t) = $tgrad_oop(u,p,t)
209+
$tgradsym(J,u,p,t) = $tgrad_iip(J,u,p,t)
210+
end
201211
else
202-
_tgrad = :nothing
212+
_tgrad = :($tgradsym = nothing)
203213
end
204214

215+
jacsym = gensym(:jac)
205216
if jac
206-
_jac = generate_jacobian(sys, dvs, ps;
217+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps;
207218
sparse=sparse, simplify=simplify,
208-
expression=Val{true}, kwargs...)[idx]
219+
expression=Val{true}, kwargs...)
220+
_jac = quote
221+
$jacsym(u,p,t) = $jac_oop(u,p,t)
222+
$jacsym(J,u,p,t) = $jac_iip(J,u,p,t)
223+
end
209224
else
210-
_jac = :nothing
225+
_jac = :($jacsym = nothing)
211226
end
212227

213228
M = calculate_massmatrix(sys)
214229

215230
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
216231

217232
jp_expr = sparse ? :(similar($(get_jac(sys)[]),Float64)) : :nothing
218-
219233
ex = quote
220-
f = $f
221-
tgrad = $_tgrad
222-
jac = $_jac
234+
$_f
235+
$_tgrad
236+
$_jac
223237
M = $_M
224238
ODEFunction{$iip}(
225-
f,
226-
jac = jac,
227-
tgrad = tgrad,
239+
$fsym,
240+
jac = $jacsym,
241+
tgrad = $tgradsym,
228242
mass_matrix = M,
229243
jac_prototype = $jp_expr,
230244
syms = $(Symbol.(states(sys))),

test/odesystem.jl

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,38 @@ generate_function(de, [x,y,z], [σ,ρ,β])
3535
jac_expr = generate_jacobian(de)
3636
jac = calculate_jacobian(de)
3737
jacfun = eval(jac_expr[2])
38-
# iip
39-
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
40-
du = zeros(3)
41-
u = collect(1:3)
42-
p = collect(4:6)
43-
f(du, u, p, 0.1)
44-
@test du == [4, 0, -16]
45-
J = zeros(3, 3)
46-
jacfun(J, u, p, t)
47-
# oop
48-
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
49-
du = @SArray zeros(3)
50-
u = SVector(1:3...)
51-
p = SVector(4:6...)
52-
@test f(u, p, 0.1) === @SArray [4, 0, -16]
38+
39+
for f in [
40+
ODEFunction(de, [x,y,z], [σ,ρ,β], tgrad = true, jac = true),
41+
eval(ODEFunctionExpr(de, [x,y,z], [σ,ρ,β], tgrad = true, jac = true)),
42+
]
43+
# iip
44+
du = zeros(3)
45+
u = collect(1:3)
46+
p = collect(4:6)
47+
f.f(du, u, p, 0.1)
48+
@test du == [4, 0, -16]
49+
50+
# oop
51+
du = @SArray zeros(3)
52+
u = SVector(1:3...)
53+
p = SVector(4:6...)
54+
@test f.f(u, p, 0.1) === @SArray [4, 0, -16]
55+
56+
# iip vs oop
57+
du = zeros(3)
58+
g = similar(du)
59+
J = zeros(3, 3)
60+
u = collect(1:3)
61+
p = collect(4:6)
62+
f.f(du, u, p, 0.1)
63+
@test du == f(u, p, 0.1)
64+
f.tgrad(g, u, p, t)
65+
@test g == f.tgrad(u, p, t)
66+
f.jac(J, u, p, t)
67+
@test J == f.jac(u, p, t)
68+
end
69+
5370

5471
eqs = [D(x) ~ σ*(y-x),
5572
D(y) ~ x*-z)-y*t,
@@ -59,6 +76,8 @@ ModelingToolkit.calculate_tgrad(de)
5976

6077
tgrad_oop, tgrad_iip = eval.(ModelingToolkit.generate_tgrad(de))
6178

79+
u = SVector(1:3...)
80+
p = SVector(4:6...)
6281
@test tgrad_oop(u,p,t) == [0.0,-u[2],0.0]
6382
du = zeros(3)
6483
tgrad_iip(du,u,p,t)

0 commit comments

Comments
 (0)