Skip to content

Commit 1a48653

Browse files
Use dispatch for in-place out-of-place function versions
Since anonymous functions cannot dispatch, this creates a gensym'd generic function with the two dispatches, making it simultaniously holding the oop and iip versions. Then ODEFunction just specifies which path to take. Defaults to saying it should be ran in-place.
1 parent 8556985 commit 1a48653

File tree

2 files changed

+36
-29
lines changed

2 files changed

+36
-29
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,18 @@ function (f::ODEToExpr)(O::Operation)
152152
end
153153
(f::ODEToExpr)(x) = convert(Expr, x)
154154

155-
function generate_jacobian(sys::ODESystem; version::FunctionVersion = ArrayFunction)
155+
function generate_jacobian(sys::ODESystem; version::FunctionVersion = nothing)
156+
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
156157
jac = calculate_jacobian(sys)
157-
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,), ODEToExpr(sys); version = version)
158+
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,), ODEToExpr(sys))
158159
end
159160

160-
function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction)
161+
function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = nothing)
162+
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
161163
rhss = [deq.rhs for deq sys.eqs]
162164
dvs′ = [clean(dv) for dv dvs]
163165
ps′ = [clean(p) for p ps]
164-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys); version = version)
166+
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys))
165167
end
166168

167169
function calculate_factorized_W(sys::ODESystem, simplify=true)
@@ -188,7 +190,8 @@ function calculate_factorized_W(sys::ODESystem, simplify=true)
188190
(Wfact,Wfact_t)
189191
end
190192

191-
function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionVersion = ArrayFunction)
193+
function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionVersion = nothing)
194+
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
192195
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
193196

194197
if version === SArrayFunction
@@ -202,8 +205,8 @@ function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionV
202205
end
203206

204207
vs, ps = sys.dvs, sys.ps
205-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys); version = version, constructor=constructor)
206-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys); version = version, constructor=constructor)
208+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
209+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
207210

208211
return (Wfact_func, Wfact_t_func)
209212
end
@@ -215,16 +218,15 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
215218
are used to set the order of the dependent variable and parameter vectors,
216219
respectively.
217220
"""
218-
function DiffEqBase.ODEFunction(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction,
219-
jac = false, Wfact = false)
220-
expr = eval(generate_function(sys, dvs, ps; version = version))
221+
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps; version::FunctionVersion = nothing,
222+
jac = false, Wfact = false) where iip
223+
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
224+
expr = eval(generate_function(sys, dvs, ps))
221225
jac_expr = jac ? nothing : eval(generate_jacobian(sys))
222-
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(calculate_factorized_W(sys))
223-
if version === ArrayFunction
224-
ODEFunction{true}(eval(expr),jac=jac_expr,
225-
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
226-
elseif version === SArrayFunction
227-
ODEFunction{false}(eval(expr),jac=jac_expr,
228-
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
229-
end
226+
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(generate_factorized_W(sys))
227+
ODEFunction{iip}(eval(expr),jac=jac_expr,
228+
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
229+
end
230+
function DiffEqBase.ODEFunction(sys::ODESystem, args...; kwargs...)
231+
ODEFunction{true}(sys, args...; kwargs...)
230232
end

src/utils.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,31 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); version::FunctionVersion, constructor=nothing)
34+
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); version::FunctionVersion=nothing, constructor=nothing)
35+
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
3536
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
3637
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
3738
(ls, rs) = zip(var_pairs..., param_pairs...)
3839

3940
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
4041

41-
if version === ArrayFunction
42-
X = gensym()
43-
sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
44-
let_expr = Expr(:let, var_eqs, build_expr(:block, sys_exprs))
45-
:(($X,u,p,$(args...)) -> $let_expr)
46-
elseif version === SArrayFunction
47-
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
48-
let_expr = Expr(:let, var_eqs, sys_expr)
49-
:((u,p,$(args...)) -> begin
42+
fname = gensym()
43+
44+
X = gensym()
45+
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
46+
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
47+
48+
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
49+
let_expr = Expr(:let, var_eqs, sys_expr)
50+
quote
51+
function $fname($X,u,p,$(args...))
52+
$ip_let_expr
53+
end
54+
function $fname(u,p,$(args...))
5055
X = $let_expr
5156
T = $(constructor === nothing ? :(StaticArrays.similar_type(typeof(u), eltype(X))) : constructor)
5257
T(X)
53-
end)
58+
end
5459
end
5560
end
5661

0 commit comments

Comments
 (0)