Skip to content

Commit cdd2843

Browse files
iW generation
1 parent 84b1079 commit cdd2843

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct DiffEqSystem <: AbstractSystem
1+
mutable struct DiffEqSystem <: AbstractSystem
22
eqs::Vector{Operation}
33
ivs::Vector{Variable}
44
dvs::Vector{Variable}
@@ -7,13 +7,14 @@ struct DiffEqSystem <: AbstractSystem
77
iv_name::Symbol
88
dv_name::Symbol
99
p_name::Symbol
10+
jac::Matrix{Expression}
1011
end
1112

1213
function DiffEqSystem(eqs, ivs, dvs, vs, ps)
1314
iv_name = ivs[1].subtype
1415
dv_name = dvs[1].subtype
1516
p_name = isempty(ps) ? :Parameter : ps[1].subtype
16-
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name)
17+
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name, Matrix{Expression}(0,0))
1718
end
1819

1920
function DiffEqSystem(eqs; iv_name = :IndependentVariable,
@@ -23,7 +24,7 @@ function DiffEqSystem(eqs; iv_name = :IndependentVariable,
2324
targetmap = Dict(iv_name => iv_name, dv_name => dv_name, v_name => v_name,
2425
p_name => p_name)
2526
ivs, dvs, vs, ps = extract_elements(eqs, targetmap)
26-
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name)
27+
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name, Matrix{Expression}(0,0))
2728
end
2829

2930
function DiffEqSystem(eqs, ivs;
@@ -32,7 +33,7 @@ function DiffEqSystem(eqs, ivs;
3233
p_name = :Parameter)
3334
targetmap = Dict(dv_name => dv_name, v_name => v_name, p_name => p_name)
3435
dvs, vs, ps = extract_elements(eqs, targetmap)
35-
DiffEqSystem(eqs, ivs, dvs, vs, ps, ivs[1].subtype, dv_name, p_name)
36+
DiffEqSystem(eqs, ivs, dvs, vs, ps, ivs[1].subtype, dv_name, p_name, Matrix{Expression}(0,0))
3637
end
3738

3839
function generate_ode_function(sys::DiffEqSystem)
@@ -81,12 +82,42 @@ function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
8182
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
8283
diff_exprs = sys.eqs[diff_idxs]
8384
jac = calculate_jacobian(sys,simplify)
85+
sys.jac = jac
8486
jac_exprs = [:(J[$i,$j] = $(Expr(jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
8587
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
8688
block = expr_arr_to_block(exprs)
8789
:((J,u,p,t)->$(block))
8890
end
8991

92+
const _γ_ = Variable(:_γ_)
93+
94+
function generate_ode_iW(sys::DiffEqSystem,simplify=true)
95+
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
96+
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
97+
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
98+
diff_exprs = sys.eqs[diff_idxs]
99+
jac = sys.jac
100+
iW = inv(I - _γ_*jac)
101+
102+
if simplify
103+
iW = simplify_constants.(iW)
104+
end
105+
106+
iW_t = inv(I/_γ_ - jac)
107+
if simplify
108+
iW_t = simplify_constants.(iW_t)
109+
end
110+
111+
iW_exprs = [:(iW[$i,$j] = $(Expr(iW[i,j]))) for i in 1:size(iW,1), j in 1:size(iW,2)]
112+
exprs = vcat(var_exprs,param_exprs,vec(iW_exprs))
113+
block = expr_arr_to_block(exprs)
114+
115+
iW_t_exprs = [:(iW[$i,$j] = $(Expr(iW_t[i,j]))) for i in 1:size(iW_t,1), j in 1:size(iW_t,2)]
116+
exprs = vcat(var_exprs,param_exprs,vec(iW_t_exprs))
117+
block2 = expr_arr_to_block(exprs)
118+
:((iW,u,p,_γ_,t)->$(block)),:((iW,u,p,_γ_,t)->$(block2))
119+
end
120+
90121
function DiffEqBase.DiffEqFunction(sys::DiffEqSystem)
91122
expr = generate_ode_function(sys)
92123
DiffEqFunction{true}(eval(expr))

test/system_construction.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ ModelingToolkit.generate_ode_function(de)
1818
jac_expr = ModelingToolkit.generate_ode_jacobian(de)
1919
jac = ModelingToolkit.calculate_jacobian(de)
2020
f = DiffEqFunction(de)
21-
W = I - jac
22-
iW = simplify_constants.(inv(W))
21+
ModelingToolkit.generate_ode_iW(de)
2322

2423
# Differential equation with automatic extraction of variables on rhs
2524
de2 = DiffEqSystem(eqs, [t])

0 commit comments

Comments
 (0)