@@ -128,13 +128,18 @@ struct SDESystem <: AbstractODESystem
128
128
The hierarchical parent system before simplification.
129
129
"""
130
130
parent:: Any
131
-
131
+ """
132
+ Signal for whether the noise equations should be treated as a scalar process. This should only
133
+ be `true` when `noiseeqs isa Vector`.
134
+ """
135
+ is_scalar_noise:: Bool
136
+
132
137
function SDESystem (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
133
138
tgrad,
134
139
jac,
135
140
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
136
141
cevents, devents, parameter_dependencies, metadata = nothing , gui_metadata = nothing ,
137
- complete = false , index_cache = nothing , parent = nothing ;
142
+ complete = false , index_cache = nothing , parent = nothing , is_scalar_noise = false ;
138
143
checks:: Union{Bool, Int} = true )
139
144
if checks == true || (checks & CheckComponents) > 0
140
145
check_independent_variables ([iv])
@@ -146,6 +151,9 @@ struct SDESystem <: AbstractODESystem
146
151
throw (ArgumentError (" Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size (neqs,1 )) != length(deqs) = $(length (deqs)) " ))
147
152
end
148
153
check_equations (equations (cevents), iv)
154
+ if is_scalar_noise && neqs isa AbstractMatrix
155
+ throw (ArgumentError (" Noise equations ill-formed. Recieved a matrix of noise equations of size $(size (neqs)) , but `is_scalar_noise` was set to `true`. Scalar noise is only compatible with an `AbstractVector` of noise equations." ))
156
+ end
149
157
end
150
158
if checks == true || (checks & CheckUnits) > 0
151
159
u = __get_unit_type (dvs, ps, iv)
@@ -154,7 +162,7 @@ struct SDESystem <: AbstractODESystem
154
162
new (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
155
163
ctrl_jac,
156
164
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
157
- parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent)
165
+ parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise )
158
166
end
159
167
end
160
168
@@ -173,7 +181,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
173
181
discrete_events = nothing ,
174
182
parameter_dependencies = nothing ,
175
183
metadata = nothing ,
176
- gui_metadata = nothing )
184
+ gui_metadata = nothing ,
185
+ complete = false ,
186
+ index_cache = nothing ,
187
+ parent = nothing ,
188
+ is_scalar_noise= false )
177
189
name === nothing &&
178
190
throw (ArgumentError (" The `name` keyword must be provided. Please consider using the `@named` macro" ))
179
191
iv′ = value (iv)
@@ -208,9 +220,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
208
220
parameter_dependencies, ps′ = process_parameter_dependencies (
209
221
parameter_dependencies, ps′)
210
222
SDESystem (Threads. atomic_add! (SYSTEM_COUNT, UInt (1 )),
211
- deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
212
- ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
213
- cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks)
223
+ deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
224
+ ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
225
+ cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
226
+ complete, index_cache, parent, is_scalar_noise; checks = checks)
214
227
end
215
228
216
229
function SDESystem (sys:: ODESystem , neqs; kwargs... )
@@ -225,6 +238,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
225
238
isequal (nameof (sys1), nameof (sys2)) &&
226
239
isequal (get_eqs (sys1), get_eqs (sys2)) &&
227
240
isequal (get_noiseeqs (sys1), get_noiseeqs (sys2)) &&
241
+ isequal (get_is_scalar_noise (sys1), get_is_scalar_noise (sys2)) &&
228
242
_eq_unordered (get_unknowns (sys1), get_unknowns (sys2)) &&
229
243
_eq_unordered (get_ps (sys1), get_ps (sys2)) &&
230
244
all (s1 == s2 for (s1, s2) in zip (get_systems (sys1), get_systems (sys2)))
@@ -601,6 +615,9 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
601
615
SDEFunctionExpr {true} (sys, args... ; kwargs... )
602
616
end
603
617
618
+
619
+ function scalar_noise end # defined in ../ext/MTKDiffEqNoiseProcess.jl
620
+
604
621
function DiffEqBase. SDEProblem {iip, specialize} (
605
622
sys:: SDESystem , u0map = [], tspan = get_tspan (sys),
606
623
parammap = DiffEqBase. NullParameters ();
@@ -616,16 +633,24 @@ function DiffEqBase.SDEProblem{iip, specialize}(
616
633
sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
617
634
618
635
noiseeqs = get_noiseeqs (sys)
636
+ is_scalar_noise = get_is_scalar_noise (sys)
619
637
if noiseeqs isa AbstractVector
620
638
noise_rate_prototype = nothing
639
+ if is_scalar_noise
640
+ noise = scalar_noise ()
641
+ else
642
+ noise = nothing
643
+ end
621
644
elseif sparsenoise
622
645
I, J, V = findnz (SparseArrays. sparse (noiseeqs))
623
646
noise_rate_prototype = SparseArrays. sparse (I, J, zero (eltype (u0)))
647
+ noise = nothing
624
648
else
625
649
noise_rate_prototype = zeros (eltype (u0), size (noiseeqs))
650
+ noise = nothing
626
651
end
627
652
628
- SDEProblem {iip} (f, u0, tspan, p; callback = cbs,
653
+ SDEProblem {iip} (f, u0, tspan, p; callback = cbs, noise,
629
654
noise_rate_prototype = noise_rate_prototype, kwargs... )
630
655
end
631
656
@@ -693,8 +718,12 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
693
718
sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
694
719
695
720
noiseeqs = get_noiseeqs (sys)
721
+ is_scalar_noise = get_is_scalar_noise (sys)
696
722
if noiseeqs isa AbstractVector
697
723
noise_rate_prototype = nothing
724
+ if is_scalar_noise
725
+ noise = scalar_noise ()
726
+ end
698
727
elseif sparsenoise
699
728
I, J, V = findnz (SparseArrays. sparse (noiseeqs))
700
729
noise_rate_prototype = SparseArrays. sparse (I, J, zero (eltype (u0)))
@@ -708,7 +737,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
708
737
tspan = $ tspan
709
738
p = $ p
710
739
noise_rate_prototype = $ noise_rate_prototype
711
- SDEProblem (f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
740
+ noise = $ noise
741
+ SDEProblem (f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise,
712
742
$ (kwargs... ))
713
743
end
714
744
! linenumbers ? Base. remove_linenums! (ex) : ex
0 commit comments