Skip to content

Commit c9f47a0

Browse files
authored
Define inplace and out of place methods in ODEFunctionExpr
In this PR, I attempt to define an inplace and out of place methods for the ODE function in `ODEFunctionExpr`. I did this on GitHub for some reason so I didn't run local tests yet. Will add a test after I open the PR.
1 parent 10e4743 commit c9f47a0

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,22 +192,34 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
192192
sparse = false, simplify=false,
193193
kwargs...) where {iip}
194194

195-
idx = iip ? 2 : 1
196-
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
195+
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)
196+
_f = quote
197+
f(u,p,t) = $f_oop(u,p,t)
198+
f(du,u,p,t) = $f_iip(du,u,p,t)
199+
end
200+
197201
if tgrad
198-
_tgrad = generate_tgrad(sys, dvs, ps;
202+
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
199203
simplify=simplify,
200-
expression=Val{true}, kwargs...)[idx]
204+
expression=Val{true}, kwargs...)
205+
_tgrad = quote
206+
tgrad(u,p,t) = $tgrad_oop(u,p,t)
207+
tgrad(J,u,p,t) = $tgrad_iip(J,u,p,t)
208+
end
201209
else
202-
_tgrad = :nothing
210+
_tgrad = :(tgrad = nothing)
203211
end
204212

205213
if jac
206-
_jac = generate_jacobian(sys, dvs, ps;
214+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps;
207215
sparse=sparse, simplify=simplify,
208-
expression=Val{true}, kwargs...)[idx]
216+
expression=Val{true}, kwargs...)
217+
_jac = quote
218+
jac(u,p,t) = $jac_oop(u,p,t)
219+
jac(J,u,p,t) = $jac_iip(J,u,p,t)
220+
end
209221
else
210-
_jac = :nothing
222+
_jac = :(jac = nothing)
211223
end
212224

213225
M = calculate_massmatrix(sys)
@@ -217,9 +229,9 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
217229
jp_expr = sparse ? :(similar($(get_jac(sys)[]),Float64)) : :nothing
218230

219231
ex = quote
220-
f = $f
221-
tgrad = $_tgrad
222-
jac = $_jac
232+
$f
233+
$_tgrad
234+
$_jac
223235
M = $_M
224236
ODEFunction{$iip}(
225237
f,

0 commit comments

Comments
 (0)