@@ -36,14 +36,22 @@ function DiffEqSystem(eqs, ivs;
36
36
DiffEqSystem (eqs, ivs, dvs, vs, ps, ivs[1 ]. subtype, dv_name, p_name, Matrix {Expression} (undef,0 ,0 ))
37
37
end
38
38
39
- function generate_ode_function (sys:: DiffEqSystem )
39
+ function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction )
40
40
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in 1 : length (sys. dvs)]
41
41
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in 1 : length (sys. ps)]
42
42
sys_exprs = build_equals_expr .(sys. eqs)
43
- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
44
- exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
45
- block = expr_arr_to_block (exprs)
46
- :((du,u,p,t)-> $ (toexpr (block)))
43
+ if version == ArrayFunction
44
+ dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
45
+ exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
46
+ block = expr_arr_to_block (exprs)
47
+ :((du,u,p,t)-> $ (toexpr (block)))
48
+ elseif version == SArrayFunction
49
+ dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
50
+ svector_expr = :(typeof (u)($ (dvar_exprs... )))
51
+ exprs = vcat (var_exprs,param_exprs,sys_exprs,svector_expr)
52
+ block = expr_arr_to_block (exprs)
53
+ :((u,p,t)-> $ (toexpr (block)))
54
+ end
47
55
end
48
56
49
57
isintermediate (eq) = eq. args[1 ]. diff == nothing
@@ -123,9 +131,13 @@ function generate_ode_iW(sys::DiffEqSystem,simplify=true)
123
131
:((iW,u,p,gam,t)-> $ (block)),:((iW,u,p,gam,t)-> $ (block2))
124
132
end
125
133
126
- function DiffEqBase. ODEFunction (sys:: DiffEqSystem )
127
- expr = generate_ode_function (sys)
128
- ODEFunction {true} (eval (expr))
134
+ function DiffEqBase. ODEFunction (sys:: DiffEqSystem ;version = ArrayFunction,kwargs... )
135
+ expr = generate_ode_function (sys;kwargs... )
136
+ if version == ArrayFunction
137
+ ODEFunction {true} (eval (expr))
138
+ elseif version == SArrayFunction
139
+ ODEFunction {false} (eval (expr))
140
+ end
129
141
end
130
142
131
143
export DiffEqSystem, ODEFunction
0 commit comments