Skip to content

Commit d9bef69

Browse files
Remove mutability from DiffEqSystem
1 parent 2d11f50 commit d9bef69

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
mutable struct DiffEqSystem <: AbstractSystem
1+
using Base: RefValue
2+
3+
4+
struct DiffEqSystem <: AbstractSystem
25
eqs::Vector{Equation}
36
ivs::Vector{Variable}
47
dvs::Vector{Variable}
58
ps::Vector{Variable}
6-
jac::Matrix{Expression}
9+
jac::RefValue{Matrix{Expression}}
710
function DiffEqSystem(eqs, ivs, dvs, ps)
811
all(!isintermediate, eqs) ||
912
throw(ArgumentError("no intermediate equations permitted in DiffEqSystem"))
1013

11-
jac = Matrix{Expression}(undef, 0, 0)
14+
jac = RefValue(Matrix{Expression}(undef, 0, 0))
1215
new(eqs, ivs, dvs, ps, jac)
1316
end
1417
end
@@ -58,18 +61,18 @@ function build_equals_expr(eq::Equation)
5861
end
5962

6063
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
64+
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
6165
rhs = [eq.rhs for eq in sys.eqs]
6266

63-
sys_exprs = calculate_jacobian(rhs, sys.dvs)
64-
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
65-
sys_exprs
67+
jac = expand_derivatives.(calculate_jacobian(rhs, sys.dvs))
68+
sys.jac[] = jac # cache Jacobian
69+
return jac
6670
end
6771

6872
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
6973
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
7074
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
7175
jac = calculate_jacobian(sys, simplify)
72-
sys.jac = jac
7376
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
7477
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
7578
block = expr_arr_to_block(exprs)
@@ -79,7 +82,7 @@ end
7982
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
8083
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
8184
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
82-
jac = sys.jac
85+
jac = calculate_jacobian(sys, simplify)
8386

8487
gam = Parameter(:gam)
8588

0 commit comments

Comments
 (0)