Skip to content

Commit ddae54c

Browse files
finish up dispatching-based versions
1 parent 1a48653 commit ddae54c

File tree

4 files changed

+25
-34
lines changed

4 files changed

+25
-34
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,12 @@ function (f::ODEToExpr)(O::Operation)
152152
end
153153
(f::ODEToExpr)(x) = convert(Expr, x)
154154

155-
function generate_jacobian(sys::ODESystem; version::FunctionVersion = nothing)
156-
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
155+
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps)
157156
jac = calculate_jacobian(sys)
158-
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,), ODEToExpr(sys))
157+
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys))
159158
end
160159

161-
function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = nothing)
162-
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
160+
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps)
163161
rhss = [deq.rhs for deq sys.eqs]
164162
dvs′ = [clean(dv) for dv dvs]
165163
ps′ = [clean(p) for p ps]
@@ -190,21 +188,14 @@ function calculate_factorized_W(sys::ODESystem, simplify=true)
190188
(Wfact,Wfact_t)
191189
end
192190

193-
function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionVersion = nothing)
194-
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
191+
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true)
195192
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
193+
siz = size(Wfact)
194+
constructor = :(x -> begin
195+
A = SMatrix{$siz...}(x)
196+
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
197+
end)
196198

197-
if version === SArrayFunction
198-
siz = size(Wfact)
199-
constructor = :(x -> begin
200-
A = SMatrix{$siz...}(x)
201-
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
202-
end)
203-
else
204-
constructor = nothing
205-
end
206-
207-
vs, ps = sys.dvs, sys.ps
208199
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
209200
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
210201

@@ -218,12 +209,12 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
218209
are used to set the order of the dependent variable and parameter vectors,
219210
respectively.
220211
"""
221-
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps; version::FunctionVersion = nothing,
222-
jac = false, Wfact = false) where iip
223-
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
212+
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
213+
version = nothing,
214+
jac = false, Wfact = false) where iip
224215
expr = eval(generate_function(sys, dvs, ps))
225-
jac_expr = jac ? nothing : eval(generate_jacobian(sys))
226-
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(generate_factorized_W(sys))
216+
jac_expr = jac ? nothing : eval(generate_jacobian(sys, dvs, ps))
217+
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(generate_factorized_W(sys, dvs, ps))
227218
ODEFunction{iip}(eval(expr),jac=jac_expr,
228219
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
229220
end

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ function calculate_jacobian(sys::NonlinearSystem)
6666
return jac
6767
end
6868

69-
function generate_jacobian(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
69+
function generate_jacobian(sys::NonlinearSystem)
7070
jac = calculate_jacobian(sys)
71-
return build_function(jac, clean.(sys.vs), sys.ps, (), NLSysToExpr(sys); version = version)
71+
return build_function(jac, clean.(sys.vs), sys.ps, (), NLSysToExpr(sys))
7272
end
7373

7474
struct NLSysToExpr
@@ -85,9 +85,9 @@ end
8585
(f::NLSysToExpr)(x) = convert(Expr, x)
8686

8787

88-
function generate_function(sys::NonlinearSystem, vs, ps; version::FunctionVersion = ArrayFunction)
88+
function generate_function(sys::NonlinearSystem, vs, ps; version = nothing)
8989
rhss = [eq.rhs for eq sys.eqs]
9090
vs′ = [clean(v) for v vs]
9191
ps′ = [clean(p) for p ps]
92-
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys); version = version)
92+
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys))
9393
end

src/utils.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); version::FunctionVersion=nothing, constructor=nothing)
35-
version != nothing && @warn("version is deprecated. Both dispatches are now constructed in the same function by defualt.")
36-
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
37-
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
34+
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); constructor=nothing)
35+
_vs = map(x-> x isa Operation ? x.op : x, vs)
36+
_ps = map(x-> x isa Operation ? x.op : x, ps)
37+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
38+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
3839
(ls, rs) = zip(var_pairs..., param_pairs...)
3940

4041
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))

test/system_construction.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ eqs = [D(x) ~ σ*(y-x),
3232
de = ODESystem(eqs)
3333
test_diffeq_inference("standard", de, t, (x, y, z), (σ, ρ, β))
3434
generate_function(de, [x,y,z], [σ,ρ,β])
35-
generate_function(de, [x,y,z], [σ,ρ,β]; version=ModelingToolkit.SArrayFunction)
3635
jac_expr = generate_jacobian(de)
3736
jac = calculate_jacobian(de)
3837
jacfun = eval(jac_expr)
@@ -52,8 +51,8 @@ FWt = zeros(3, 3)
5251
fw(FW, u, p, 0.2, 0.1)
5352
fwt(FWt, u, p, 0.2, 0.1)
5453
# oop
55-
f = ODEFunction(de, [x,y,z], [σ,ρ,β]; version=ModelingToolkit.SArrayFunction)
56-
fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de; version=ModelingToolkit.SArrayFunction))
54+
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
55+
fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de))
5756
du = @SArray zeros(3)
5857
u = SVector(1:3...)
5958
p = SVector(4:6...)

0 commit comments

Comments
 (0)