Skip to content

Commit 4a8a86e

Browse files
Refactor function building
1 parent 40b0370 commit 4a8a86e

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/systems/systems.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,28 @@ end
2121
function generate_function(sys::AbstractSystem; version::FunctionVersion = ArrayFunction)
2222
sys_eqs = system_eqs(sys)
2323
vs, ps = system_vars(sys), system_params(sys)
24+
return build_function([eq.rhs for eq sys_eqs], vs, ps; version = version)
25+
end
2426

27+
function build_function(rhss, vs, ps; version::FunctionVersion)
2528
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
2629
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
2730
(ls, rs) = collect(zip(var_pairs..., param_pairs...))
2831

2932
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
30-
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys_eqs])
31-
let_expr = Expr(:let, var_eqs, sys_exprs)
3233

3334
if version === ArrayFunction
34-
:((du,u,p,t) -> du .= $let_expr)
35+
X = gensym()
36+
sys_exprs = [:($X[$i] = $(convert(Expr, rhs))) for (i, rhs) enumerate(rhss)]
37+
let_expr = Expr(:let, var_eqs, build_expr(:block, sys_exprs))
38+
:(($X,u,p,t) -> $let_expr)
3539
elseif version === SArrayFunction
40+
sys_expr = build_expr(:tuple, [convert(Expr, rhs) for rhs rhss])
41+
let_expr = Expr(:let, var_eqs, sys_expr)
3642
:((u,p,t) -> begin
37-
du = $let_expr
38-
T = StaticArrays.similar_type(typeof(u), eltype(du))
39-
T(du)
43+
X = $let_expr
44+
T = StaticArrays.similar_type(typeof(u), eltype(X))
45+
T(X)
4046
end)
4147
end
4248
end

0 commit comments

Comments
 (0)