Skip to content

Commit e1a9ac4

Browse files
MasonProtterChrisRackauckas
authored andcommitted
tell SDEProblem that the system contains scalar noise
1 parent 00f42a0 commit e1a9ac4

File tree

6 files changed

+61
-21
lines changed

6 files changed

+61
-21
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
5656
[weakdeps]
5757
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
5858
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
59+
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
5960

6061
[extensions]
6162
MTKBifurcationKitExt = "BifurcationKit"
6263
MTKDeepDiffsExt = "DeepDiffs"
64+
MTKDiffEqNoiseProcess = "DiffEqNoiseProcess"
6365

6466
[compat]
6567
AbstractTrees = "0.3, 0.4"

ext/MTKDiffEqNoiseProcess.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module MTKDiffEqNoiseProcess
2+
3+
using ModelingToolkit: ModelingToolkit
4+
using DiffEqNoiseProcess: WienerProcess
5+
6+
ModelingToolkit.scalar_noise() = WienerProcess(0.0, 0.0, 0.0)
7+
8+
end

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,8 @@ for prop in [:eqs
655655
:solved_unknowns
656656
:split_idxs
657657
:parent
658-
:index_cache]
658+
:index_cache
659+
:is_scalar_noise]
659660
fname_get = Symbol(:get_, prop)
660661
fname_has = Symbol(:has_, prop)
661662
@eval begin

src/systems/diffeqs/sdesystem.jl

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,18 @@ struct SDESystem <: AbstractODESystem
128128
The hierarchical parent system before simplification.
129129
"""
130130
parent::Any
131-
131+
"""
132+
Signal for whether the noise equations should be treated as a scalar process. This should only
133+
be `true` when `noiseeqs isa Vector`.
134+
"""
135+
is_scalar_noise::Bool
136+
132137
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
133138
tgrad,
134139
jac,
135140
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
136141
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
137-
complete = false, index_cache = nothing, parent = nothing;
142+
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise=false;
138143
checks::Union{Bool, Int} = true)
139144
if checks == true || (checks & CheckComponents) > 0
140145
check_independent_variables([iv])
@@ -146,6 +151,9 @@ struct SDESystem <: AbstractODESystem
146151
throw(ArgumentError("Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size(neqs,1)) != length(deqs) = $(length(deqs))"))
147152
end
148153
check_equations(equations(cevents), iv)
154+
if is_scalar_noise && neqs isa AbstractMatrix
155+
throw(ArgumentError("Noise equations ill-formed. Recieved a matrix of noise equations of size $(size(neqs)), but `is_scalar_noise` was set to `true`. Scalar noise is only compatible with an `AbstractVector` of noise equations."))
156+
end
149157
end
150158
if checks == true || (checks & CheckUnits) > 0
151159
u = __get_unit_type(dvs, ps, iv)
@@ -154,7 +162,7 @@ struct SDESystem <: AbstractODESystem
154162
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
155163
ctrl_jac,
156164
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
157-
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent)
165+
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise)
158166
end
159167
end
160168

@@ -173,7 +181,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
173181
discrete_events = nothing,
174182
parameter_dependencies = nothing,
175183
metadata = nothing,
176-
gui_metadata = nothing)
184+
gui_metadata = nothing,
185+
complete = false,
186+
index_cache = nothing,
187+
parent = nothing,
188+
is_scalar_noise=false)
177189
name === nothing &&
178190
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
179191
iv′ = value(iv)
@@ -208,9 +220,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
208220
parameter_dependencies, ps′ = process_parameter_dependencies(
209221
parameter_dependencies, ps′)
210222
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
211-
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
212-
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
213-
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks)
223+
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
224+
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
225+
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
226+
complete, index_cache, parent, is_scalar_noise; checks = checks)
214227
end
215228

216229
function SDESystem(sys::ODESystem, neqs; kwargs...)
@@ -225,6 +238,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
225238
isequal(nameof(sys1), nameof(sys2)) &&
226239
isequal(get_eqs(sys1), get_eqs(sys2)) &&
227240
isequal(get_noiseeqs(sys1), get_noiseeqs(sys2)) &&
241+
isequal(get_is_scalar_noise(sys1), get_is_scalar_noise(sys2)) &&
228242
_eq_unordered(get_unknowns(sys1), get_unknowns(sys2)) &&
229243
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
230244
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
@@ -601,6 +615,9 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
601615
SDEFunctionExpr{true}(sys, args...; kwargs...)
602616
end
603617

618+
619+
function scalar_noise end # defined in ../ext/MTKDiffEqNoiseProcess.jl
620+
604621
function DiffEqBase.SDEProblem{iip, specialize}(
605622
sys::SDESystem, u0map = [], tspan = get_tspan(sys),
606623
parammap = DiffEqBase.NullParameters();
@@ -616,16 +633,24 @@ function DiffEqBase.SDEProblem{iip, specialize}(
616633
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
617634

618635
noiseeqs = get_noiseeqs(sys)
636+
is_scalar_noise = get_is_scalar_noise(sys)
619637
if noiseeqs isa AbstractVector
620638
noise_rate_prototype = nothing
639+
if is_scalar_noise
640+
noise = scalar_noise()
641+
else
642+
noise = nothing
643+
end
621644
elseif sparsenoise
622645
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
623646
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
647+
noise = nothing
624648
else
625649
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
650+
noise = nothing
626651
end
627652

628-
SDEProblem{iip}(f, u0, tspan, p; callback = cbs,
653+
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
629654
noise_rate_prototype = noise_rate_prototype, kwargs...)
630655
end
631656

@@ -693,8 +718,12 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
693718
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
694719

695720
noiseeqs = get_noiseeqs(sys)
721+
is_scalar_noise = get_is_scalar_noise(sys)
696722
if noiseeqs isa AbstractVector
697723
noise_rate_prototype = nothing
724+
if is_scalar_noise
725+
noise = scalar_noise()
726+
end
698727
elseif sparsenoise
699728
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
700729
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
@@ -708,7 +737,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
708737
tspan = $tspan
709738
p = $p
710739
noise_rate_prototype = $noise_rate_prototype
711-
SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype,
740+
noise = $noise
741+
SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise,
712742
$(kwargs...))
713743
end
714744
!linenumbers ? Base.remove_linenums!(ex) : ex

src/systems/systems.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,21 +128,20 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
128128
@views copyto!(sorted_g_rows[i, :], g[g_row, :])
129129
end
130130
# Fix for https://github.com/SciML/ModelingToolkit.jl/issues/2490
131-
noise_eqs = if isdiag(sorted_g_rows)
131+
if isdiag(sorted_g_rows)
132132
# If the noise matrix is diagonal, then we just give solver just takes a vector column of equations
133133
# and it interprets that as diagonal noise.
134-
diag(sorted_g_rows)
134+
noise_eqs = diag(sorted_g_rows)
135+
is_scalar_noise = false
135136
elseif sorted_g_rows isa AbstractMatrix && size(sorted_g_rows, 2) == 1
136-
##-------------------------------------------------------------------------------
137-
## TODO: re-enable this code once we add a way to signal that the noise is scalar
138-
# sorted_g_rows[:, 1]
139-
##-------------------------------------------------------------------------------
140-
sorted_g_rows
137+
noise_eqs = sorted_g_rows[:, 1]
138+
is_scalar_noise = true
141139
else
142-
sorted_g_rows
140+
noise_eqs = sorted_g_rows
141+
is_scalar_noise = false
143142
end
144143
return SDESystem(full_equations(ode_sys), noise_eqs,
145144
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys);
146-
name = nameof(ode_sys))
145+
name = nameof(ode_sys), is_scalar_noise)
147146
end
148147
end

test/sdesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ let
681681
]
682682
prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
683683
# TODO: re-enable this when we support scalar noise
684-
@test_broken solve(prob, SOSRI()).retcode == ReturnCode.Success
684+
@test solve(prob, SOSRI()).retcode == ReturnCode.Success
685685
end
686686

687687
let # test to make sure that scalar noise always recieve the same kicks
@@ -692,7 +692,7 @@ let # test to make sure that scalar noise always recieve the same kicks
692692

693693
@mtkbuild de = System(eqs, t)
694694
prob = SDEProblem(de, [x => 0, y => 0], (0.0, 10.0), [])
695-
sol = solve(prob, ImplicitEM())
695+
sol = solve(prob, SOSRI())
696696
@test sol[end][1] == sol[end][2]
697697
end
698698

0 commit comments

Comments
 (0)