1
- mutable struct DiffEqSystem <: AbstractSystem
2
- eqs:: Vector{Equation}
3
- ivs:: Vector{Variable}
1
+ using Base: RefValue
2
+
3
+
4
+ isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
5
+
6
+ struct DiffEq # D(x) = t
7
+ D:: Differential # D
8
+ var:: Variable # x
9
+ rhs:: Expression # t
10
+ end
11
+ function Base. convert (:: Type{DiffEq} , eq:: Equation )
12
+ isintermediate (eq) && throw (ArgumentError (" intermediate equation received" ))
13
+ return DiffEq (eq. lhs. op, eq. lhs. args[1 ], eq. rhs)
14
+ end
15
+ Base.:(== )(a:: DiffEq , b:: DiffEq ) = (a. D, a. var, a. rhs) == (b. D, b. var, b. rhs)
16
+ get_args (eq:: DiffEq ) = Expression[eq. var, eq. rhs]
17
+
18
+ struct DiffEqSystem <: AbstractSystem
19
+ eqs:: Vector{DiffEq}
20
+ iv:: Variable
4
21
dvs:: Vector{Variable}
5
22
ps:: Vector{Variable}
6
- jac:: Matrix{Expression}
7
- function DiffEqSystem (eqs, ivs, dvs, ps, jac)
8
- all (! isintermediate, eqs) ||
9
- throw (ArgumentError (" no intermediate equations permitted in DiffEqSystem" ))
10
-
11
- new (eqs, ivs, dvs, ps, jac)
23
+ jac:: RefValue{Matrix{Expression}}
24
+ function DiffEqSystem (eqs, iv, dvs, ps)
25
+ jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
26
+ new (eqs, iv, dvs, ps, jac)
12
27
end
13
28
end
14
29
15
- DiffEqSystem (eqs, ivs, dvs, ps) = DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
16
-
17
30
function DiffEqSystem (eqs)
18
31
dvs, = extract_elements (eqs, [_is_dependent])
19
32
ivs = unique (vcat ((dv. dependents for dv ∈ dvs). .. ))
20
- ps, = extract_elements (eqs, [_is_parameter (ivs)])
21
- DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
33
+ length (ivs) == 1 || throw (ArgumentError (" one independent variable currently supported" ))
34
+ iv = first (ivs)
35
+ ps, = extract_elements (eqs, [_is_parameter (iv)])
36
+ DiffEqSystem (eqs, iv, dvs, ps)
22
37
end
23
38
24
- function DiffEqSystem (eqs, ivs )
25
- dvs, ps = extract_elements (eqs, [_is_dependent, _is_parameter (ivs )])
26
- DiffEqSystem (eqs, ivs , dvs, ps, Matrix {Expression} (undef, 0 , 0 ) )
39
+ function DiffEqSystem (eqs, iv )
40
+ dvs, ps = extract_elements (eqs, [_is_dependent, _is_parameter (iv )])
41
+ DiffEqSystem (eqs, iv , dvs, ps)
27
42
end
28
43
29
- isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
30
-
31
44
32
- function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction)
45
+ function generate_ode_function (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
33
46
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
34
47
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
35
48
sys_exprs = build_equals_expr .(sys. eqs)
36
- if version == ArrayFunction
37
- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[ 1 ] . name) " ))) for i in eachindex (sys. dvs)]
49
+ if version === ArrayFunction
50
+ dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. iv . name) " ))) for i in eachindex (sys. dvs)]
38
51
exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
39
52
block = expr_arr_to_block (exprs)
40
53
:((du,u,p,t)-> $ (toexpr (block)))
41
- elseif version == SArrayFunction
42
- dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[ 1 ] . name) " ))) for i in eachindex (sys. dvs)]
54
+ elseif version === SArrayFunction
55
+ dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. iv . name) " ))) for i in eachindex (sys. dvs)]
43
56
svector_expr = quote
44
57
E = eltype (tuple ($ (dvar_exprs... )))
45
58
T = StaticArrays. similar_type (typeof (u), E)
@@ -51,26 +64,24 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
51
64
end
52
65
end
53
66
54
- function build_equals_expr (eq:: Equation )
55
- @assert ! isintermediate (eq)
56
-
57
- lhs = Symbol (eq. lhs. args[1 ]. name, :_ , eq. lhs. op. x. name)
67
+ function build_equals_expr (eq:: DiffEq )
68
+ lhs = Symbol (eq. var. name, :_ , eq. D. x. name)
58
69
return :($ lhs = $ (convert (Expr, eq. rhs)))
59
70
end
60
71
61
72
function calculate_jacobian (sys:: DiffEqSystem , simplify= true )
73
+ isempty (sys. jac[]) || return sys. jac[] # use cached Jacobian, if possible
62
74
rhs = [eq. rhs for eq in sys. eqs]
63
75
64
- sys_exprs = calculate_jacobian (rhs, sys. dvs)
65
- sys_exprs = Expression[ expand_derivatives (expr) for expr in sys_exprs]
66
- sys_exprs
76
+ jac = expand_derivatives .( calculate_jacobian (rhs, sys. dvs) )
77
+ sys . jac[] = jac # cache Jacobian
78
+ return jac
67
79
end
68
80
69
81
function generate_ode_jacobian (sys:: DiffEqSystem , simplify= true )
70
82
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
71
83
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
72
84
jac = calculate_jacobian (sys, simplify)
73
- sys. jac = jac
74
85
jac_exprs = [:(J[$ i,$ j] = $ (convert (Expr, jac[i,j]))) for i in 1 : size (jac,1 ), j in 1 : size (jac,2 )]
75
86
exprs = vcat (var_exprs,param_exprs,vec (jac_exprs))
76
87
block = expr_arr_to_block (exprs)
80
91
function generate_ode_iW (sys:: DiffEqSystem , simplify= true )
81
92
var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
82
93
param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
83
- jac = sys. jac
94
+ jac = calculate_jacobian ( sys, simplify)
84
95
85
96
gam = Parameter (:gam )
86
97
@@ -109,12 +120,12 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
109
120
:((iW,u,p,gam,t)-> $ (block)),:((iW,u,p,gam,t)-> $ (block2))
110
121
end
111
122
112
- function DiffEqBase. ODEFunction (sys:: DiffEqSystem ;version = ArrayFunction,kwargs ... )
113
- expr = generate_ode_function (sys;version= version,kwargs ... )
114
- if version == ArrayFunction
115
- ODEFunction {true} (eval (expr))
116
- elseif version == SArrayFunction
117
- ODEFunction {false} (eval (expr))
123
+ function DiffEqBase. ODEFunction (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
124
+ expr = generate_ode_function (sys; version = version)
125
+ if version === ArrayFunction
126
+ ODEFunction {true} (eval (expr))
127
+ elseif version === SArrayFunction
128
+ ODEFunction {false} (eval (expr))
118
129
end
119
130
end
120
131
0 commit comments