Skip to content

Commit e96bd34

Browse files
committed
Adding control jacobian and control parameter specs to sdesystem
1 parent 8133eca commit e96bd34

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector
3232
"""Parameter variables. Must not contain the independent variable."""
3333
ps::Vector
34+
"""Control parameters (some subset of `ps`)."""
3435
ctrls::Vector
36+
"""Observed states."""
3537
observed::Vector{Equation}
3638
"""
3739
Time-derivative matrix. Note: this field will not be defined until

src/systems/diffeqs/sdesystem.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ struct SDESystem <: AbstractODESystem
3737
states::Vector
3838
"""Parameter variables. Must not contain the independent variable."""
3939
ps::Vector
40-
observed::Vector
40+
"""Control parameters (some subset of `ps`)."""
41+
ctrls::Vector
42+
"""Observed states."""
43+
observed::Vector{Equation}
4144
"""
4245
Time-derivative matrix. Note: this field will not be defined until
4346
[`calculate_tgrad`](@ref) is called on the system.
@@ -49,6 +52,11 @@ struct SDESystem <: AbstractODESystem
4952
"""
5053
jac::RefValue
5154
"""
55+
Control Jacobian matrix. Note: this field will not be defined until
56+
[`calculate_control_jacobian`](@ref) is called on the system.
57+
"""
58+
ctrl_jac::RefValue{Any}
59+
"""
5260
`Wfact` matrix. Note: this field will not be defined until
5361
[`generate_factorized_W`](@ref) is called on the system.
5462
"""
@@ -85,7 +93,8 @@ struct SDESystem <: AbstractODESystem
8593
end
8694

8795
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
88-
observed = [],
96+
controls = Num[],
97+
observed = Num[],
8998
systems = SDESystem[],
9099
default_u0=Dict(),
91100
default_p=Dict(),
@@ -96,6 +105,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
96105
iv′ = value(iv)
97106
dvs′ = value.(dvs)
98107
ps′ = value.(ps)
108+
ctrl′ = value.(controls)
109+
99110
sysnames = nameof.(systems)
100111
if length(unique(sysnames)) != length(sysnames)
101112
throw(ArgumentError("System names must be unique."))
@@ -108,9 +119,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
108119

109120
tgrad = RefValue(Vector{Num}(undef, 0))
110121
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
122+
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
111123
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
112124
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
113-
SDESystem(deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
125+
SDESystem(deqs, neqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
114126
end
115127

116128
function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
@@ -157,10 +169,6 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
157169
SDESystem(deqs,get_noiseeqs(sys),get_iv(sys),states(sys),parameters(sys))
158170
end
159171

160-
161-
162-
163-
164172
"""
165173
```julia
166174
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;

0 commit comments

Comments
 (0)