@@ -36,14 +36,23 @@ function DiffEqSystem(eqs, ivs;
36
36
DiffEqSystem (eqs, ivs, dvs, vs, ps, ivs[1 ]. subtype, dv_name, p_name, Matrix {Expression} (undef,0 ,0 ))
37
37
end
38
38
39
- function generate_ode_function (sys:: DiffEqSystem )
39
+ function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction )
40
40
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in 1 : length (sys. dvs)]
41
41
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in 1 : length (sys. ps)]
42
42
sys_exprs = build_equals_expr .(sys. eqs)
43
- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
44
- exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
45
- block = expr_arr_to_block (exprs)
46
- :((du,u,p,t)-> $ (block))
43
+
44
+ if version == ArrayFunction
45
+ dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
46
+ exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
47
+ block = expr_arr_to_block (exprs)
48
+ :((du,u,p,t)-> $ (block))
49
+ elseif version == SArrayFunction
50
+ dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
51
+ svector_expr = :(typeof (u)($ (dvar_exprs... )))
52
+ exprs = vcat (var_exprs,param_exprs,sys_exprs,svector_expr)
53
+ block = expr_arr_to_block (exprs)
54
+ :((u,p,t)-> $ (block))
55
+ end
47
56
end
48
57
49
58
isintermediate (eq) = eq. args[1 ]. diff == nothing
0 commit comments