|
1 |
| -mutable struct DiffEqSystem <: AbstractSystem |
| 1 | +using Base: RefValue |
| 2 | + |
| 3 | + |
| 4 | +struct DiffEqSystem <: AbstractSystem |
2 | 5 | eqs::Vector{Equation}
|
3 | 6 | ivs::Vector{Variable}
|
4 | 7 | dvs::Vector{Variable}
|
5 | 8 | ps::Vector{Variable}
|
6 |
| - jac::Matrix{Expression} |
| 9 | + jac::RefValue{Matrix{Expression}} |
7 | 10 | function DiffEqSystem(eqs, ivs, dvs, ps)
|
8 | 11 | all(!isintermediate, eqs) ||
|
9 | 12 | throw(ArgumentError("no intermediate equations permitted in DiffEqSystem"))
|
10 | 13 |
|
11 |
| - jac = Matrix{Expression}(undef, 0, 0) |
| 14 | + jac = RefValue(Matrix{Expression}(undef, 0, 0)) |
12 | 15 | new(eqs, ivs, dvs, ps, jac)
|
13 | 16 | end
|
14 | 17 | end
|
@@ -58,18 +61,18 @@ function build_equals_expr(eq::Equation)
|
58 | 61 | end
|
59 | 62 |
|
60 | 63 | function calculate_jacobian(sys::DiffEqSystem, simplify=true)
|
| 64 | + isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible |
61 | 65 | rhs = [eq.rhs for eq in sys.eqs]
|
62 | 66 |
|
63 |
| - sys_exprs = calculate_jacobian(rhs, sys.dvs) |
64 |
| - sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs] |
65 |
| - sys_exprs |
| 67 | + jac = expand_derivatives.(calculate_jacobian(rhs, sys.dvs)) |
| 68 | + sys.jac[] = jac # cache Jacobian |
| 69 | + return jac |
66 | 70 | end
|
67 | 71 |
|
68 | 72 | function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
|
69 | 73 | var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
|
70 | 74 | param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
|
71 | 75 | jac = calculate_jacobian(sys, simplify)
|
72 |
| - sys.jac = jac |
73 | 76 | jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
|
74 | 77 | exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
|
75 | 78 | block = expr_arr_to_block(exprs)
|
|
79 | 82 | function generate_ode_iW(sys::DiffEqSystem, simplify=true)
|
80 | 83 | var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
|
81 | 84 | param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
|
82 |
| - jac = sys.jac |
| 85 | + jac = calculate_jacobian(sys, simplify) |
83 | 86 |
|
84 | 87 | gam = Parameter(:gam)
|
85 | 88 |
|
|
0 commit comments