Skip to content

Commit ca8c538

Browse files
allow construction of SArray derivative function
1 parent 082e263 commit ca8c538

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Base.convert(::Type{Variable},x::Int64) = Constant(x)
2323

2424
function caclulate_jacobian end
2525

26+
@enum FunctionVersions ArrayFunction=1 SArrayFunction=2
27+
2628
include("operations.jl")
2729
include("operators.jl")
2830
include("systems/diffeqs/diffeqsystem.jl")

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,23 @@ function DiffEqSystem(eqs, ivs;
3636
DiffEqSystem(eqs, ivs, dvs, vs, ps, ivs[1].subtype, dv_name, p_name, Matrix{Expression}(undef,0,0))
3737
end
3838

39-
function generate_ode_function(sys::DiffEqSystem)
39+
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4040
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
4141
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
4242
sys_exprs = build_equals_expr.(sys.eqs)
43-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
44-
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
45-
block = expr_arr_to_block(exprs)
46-
:((du,u,p,t)->$(block))
43+
44+
if version == ArrayFunction
45+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
46+
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
47+
block = expr_arr_to_block(exprs)
48+
:((du,u,p,t)->$(block))
49+
elseif version == SArrayFunction
50+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
51+
svector_expr = :(typeof(u)($(dvar_exprs...)))
52+
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
53+
block = expr_arr_to_block(exprs)
54+
:((u,p,t)->$(block))
55+
end
4756
end
4857

4958
isintermediate(eq) = eq.args[1].diff == nothing

test/system_construction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ eqs = [D*x ~ σ*(y-x),
1515
D*z ~ x*y - β*z]
1616
de = DiffEqSystem(eqs,[t],[x,y,z],Variable[],[σ,ρ,β])
1717
ModelingToolkit.generate_ode_function(de)
18+
ModelingToolkit.generate_ode_function(de;version=ModelingToolkit.SArrayFunction)
1819
jac_expr = ModelingToolkit.generate_ode_jacobian(de)
1920
jac = ModelingToolkit.calculate_jacobian(de)
2021
f = ODEFunction(de)

0 commit comments

Comments
 (0)