@@ -37,7 +37,10 @@ struct SDESystem <: AbstractODESystem
37
37
states:: Vector
38
38
""" Parameter variables. Must not contain the independent variable."""
39
39
ps:: Vector
40
- observed:: Vector
40
+ """ Control parameters (some subset of `ps`)."""
41
+ ctrls:: Vector
42
+ """ Observed states."""
43
+ observed:: Vector{Equation}
41
44
"""
42
45
Time-derivative matrix. Note: this field will not be defined until
43
46
[`calculate_tgrad`](@ref) is called on the system.
@@ -49,6 +52,11 @@ struct SDESystem <: AbstractODESystem
49
52
"""
50
53
jac:: RefValue
51
54
"""
55
+ Control Jacobian matrix. Note: this field will not be defined until
56
+ [`calculate_control_jacobian`](@ref) is called on the system.
57
+ """
58
+ ctrl_jac:: RefValue{Any}
59
+ """
52
60
`Wfact` matrix. Note: this field will not be defined until
53
61
[`generate_factorized_W`](@ref) is called on the system.
54
62
"""
@@ -85,7 +93,8 @@ struct SDESystem <: AbstractODESystem
85
93
end
86
94
87
95
function SDESystem (deqs:: AbstractVector{<:Equation} , neqs, iv, dvs, ps;
88
- observed = [],
96
+ controls = Num[],
97
+ observed = Num[],
89
98
systems = SDESystem[],
90
99
default_u0= Dict (),
91
100
default_p= Dict (),
@@ -96,6 +105,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
96
105
iv′ = value (iv)
97
106
dvs′ = value .(dvs)
98
107
ps′ = value .(ps)
108
+ ctrl′ = value .(controls)
109
+
99
110
sysnames = nameof .(systems)
100
111
if length (unique (sysnames)) != length (sysnames)
101
112
throw (ArgumentError (" System names must be unique." ))
@@ -108,9 +119,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
108
119
109
120
tgrad = RefValue (Vector {Num} (undef, 0 ))
110
121
jac = RefValue {Any} (Matrix {Num} (undef, 0 , 0 ))
122
+ ctrl_jac = RefValue {Any} (Matrix {Num} (undef, 0 , 0 ))
111
123
Wfact = RefValue (Matrix {Num} (undef, 0 , 0 ))
112
124
Wfact_t = RefValue (Matrix {Num} (undef, 0 , 0 ))
113
- SDESystem (deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
125
+ SDESystem (deqs, neqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac , Wfact, Wfact_t, name, systems, defaults, connection_type)
114
126
end
115
127
116
128
function generate_diffusion_function (sys:: SDESystem , dvs = states (sys), ps = parameters (sys); kwargs... )
@@ -157,10 +169,6 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
157
169
SDESystem (deqs,get_noiseeqs (sys),get_iv (sys),states (sys),parameters (sys))
158
170
end
159
171
160
-
161
-
162
-
163
-
164
172
"""
165
173
```julia
166
174
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;
0 commit comments