@@ -4,6 +4,12 @@ mutable struct DiffEqSystem <: AbstractSystem
4
4
dvs:: Vector{Variable}
5
5
ps:: Vector{Variable}
6
6
jac:: Matrix{Expression}
7
+ function DiffEqSystem (eqs, ivs, dvs, ps, jac)
8
+ all (! isintermediate, eqs) ||
9
+ throw (ArgumentError (" no intermediate equations permitted in DiffEqSystem" ))
10
+
11
+ new (eqs, ivs, dvs, ps, jac)
12
+ end
7
13
end
8
14
9
15
DiffEqSystem (eqs, ivs, dvs, ps) = DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
@@ -20,6 +26,9 @@ function DiffEqSystem(eqs, ivs)
20
26
DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
21
27
end
22
28
29
+ isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
30
+
31
+
23
32
function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction)
24
33
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
25
34
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
@@ -42,8 +51,6 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
42
51
end
43
52
end
44
53
45
- isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
46
-
47
54
function build_equals_expr (eq:: Equation )
48
55
@assert ! isintermediate (eq)
49
56
@@ -52,13 +59,7 @@ function build_equals_expr(eq::Equation)
52
59
end
53
60
54
61
function calculate_jacobian (sys:: DiffEqSystem , simplify= true )
55
- calcs, diff_exprs = partition (isintermediate, sys. eqs)
56
- rhs = [eq. rhs for eq in diff_exprs]
57
-
58
- # Handle intermediate calculations by substitution
59
- for calc ∈ calcs
60
- find_replace! .(rhs, calc. lhs, calc. rhs)
61
- end
62
+ rhs = [eq. rhs for eq in sys. eqs]
62
63
63
64
sys_exprs = calculate_jacobian (rhs, sys. dvs)
64
65
sys_exprs = Expression[expand_derivatives (expr) for expr in sys_exprs]
68
69
function generate_ode_jacobian (sys:: DiffEqSystem , simplify= true )
69
70
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
70
71
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
71
- diff_exprs = filter (! isintermediate, sys. eqs)
72
72
jac = calculate_jacobian (sys, simplify)
73
73
sys. jac = jac
74
74
jac_exprs = [:(J[$ i,$ j] = $ (convert (Expr, jac[i,j]))) for i in 1 : size (jac,1 ), j in 1 : size (jac,2 )]
80
80
function generate_ode_iW (sys:: DiffEqSystem , simplify= true )
81
81
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
82
82
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
83
- diff_exprs = filter (! isintermediate, sys. eqs)
84
83
jac = sys. jac
85
84
86
85
gam = Parameter (:gam )
0 commit comments