Skip to content

Commit 948b5d2

Browse files
keep expression output
1 parent 2fad705 commit 948b5d2

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,16 @@ function (f::ODEToExpr)(O::Operation)
160160
end
161161
(f::ODEToExpr)(x) = convert(Expr, x)
162162

163-
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps)
163+
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
164164
jac = calculate_jacobian(sys)
165-
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys))
165+
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression)
166166
end
167167

168-
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps)
168+
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
169169
rhss = [deq.rhs for deq sys.eqs]
170170
dvs′ = [clean(dv) for dv dvs]
171171
ps′ = [clean(p) for p ps]
172-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys))
172+
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression)
173173
end
174174

175175
function calculate_factorized_W(sys::ODESystem, simplify=true)
@@ -196,16 +196,16 @@ function calculate_factorized_W(sys::ODESystem, simplify=true)
196196
(Wfact,Wfact_t)
197197
end
198198

199-
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true)
199+
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true})
200200
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
201201
siz = size(Wfact)
202202
constructor = :(x -> begin
203203
A = SMatrix{$siz...}(x)
204204
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
205205
end)
206206

207-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
208-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys);constructor=constructor)
207+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
208+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
209209

210210
return (Wfact_func, Wfact_t_func)
211211
end
@@ -220,21 +220,21 @@ respectively.
220220
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
221221
version = nothing,
222222
jac = false, Wfact = false) where {iip}
223-
f_oop,f_iip = generate_function(sys, dvs, ps)
223+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
224224

225225
f(u,p,t) = f_oop(u,p,t)
226226
f(du,u,p,t) = f_iip(du,u,p,t)
227227

228228
if jac
229-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps)
229+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
230230
_jac(u,p,t) = jac_oop(u,p,t)
231231
_jac(J,u,p,t) = jac_iip(J,u,p,t)
232232
else
233233
_jac = nothing
234234
end
235235

236236
if Wfact
237-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps)
237+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, Val{false})
238238
Wfact_oop, Wfact_iip = tmp_Wfact
239239
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
240240
_Wfact(u,p,t) = Wfact_oop(u,p,t)

src/utils.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mk_function(args, kwargs, body) =
3838
GG.RuntimeFn{Args, Kwargs, Body}()
3939
end
4040

41-
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; constructor=nothing)
41+
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true}; constructor=nothing)
4242
_vs = map(x-> x isa Operation ? x.op : x, vs)
4343
_ps = map(x-> x isa Operation ? x.op : x, ps)
4444
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -57,7 +57,24 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; co
5757
let_expr = Expr(:let, var_eqs, sys_expr)
5858

5959
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
60-
mk_function(fargs,:(),let_expr), mk_function(:($X,$(fargs.args...)),:(),ip_let_expr)
60+
61+
if expression == Val{true}
62+
return quote
63+
function $fname($X,$(fargs.args...))
64+
$ip_let_expr
65+
nothing
66+
end
67+
function $fname($(fargs.args...))
68+
X = $let_expr
69+
T = promote_type(map(typeof,X)...)
70+
convert.(T,X)
71+
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, T, $(size(rhss)...)); vec(du) .= x; du)) : constructor)
72+
construct(X)
73+
end
74+
end
75+
else
76+
return mk_function(fargs,:(),let_expr), mk_function(:($X,$(fargs.args...)),:(),ip_let_expr)
77+
end
6178
end
6279

6380
is_constant(::Constant) = true

0 commit comments

Comments
 (0)