Skip to content

Commit 51aea4a

Browse files
Merge pull request #3266 from AayushSabharwal/as/sde-ss
feat: enable `structural_simplify(::SDESystem)`
2 parents d01496f + ea9b6bd commit 51aea4a

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,47 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
263263
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
264264
end
265265

266+
"""
267+
function ODESystem(sys::SDESystem)
268+
269+
Convert an `SDESystem` to the equivalent `ODESystem` using `@brownian` variables instead
270+
of noise equations. The returned system will not be `iscomplete` and will not have an
271+
index cache, regardless of `iscomplete(sys)`.
272+
"""
273+
function ODESystem(sys::SDESystem)
274+
neqs = get_noiseeqs(sys)
275+
eqs = equations(sys)
276+
is_scalar_noise = get_is_scalar_noise(sys)
277+
nbrownian = if is_scalar_noise
278+
length(neqs)
279+
else
280+
size(neqs, 2)
281+
end
282+
brownvars = map(1:nbrownian) do i
283+
name = gensym(Symbol(:brown_, i))
284+
only(@brownian $name)
285+
end
286+
if is_scalar_noise
287+
brownterms = reduce(+, neqs .* brownvars; init = 0)
288+
neweqs = map(eqs) do eq
289+
eq.lhs ~ eq.rhs + brownterms
290+
end
291+
else
292+
if neqs isa AbstractVector
293+
neqs = reshape(neqs, (length(neqs), 1))
294+
end
295+
brownterms = neqs * brownvars
296+
neweqs = map(eqs, brownterms) do eq, brown
297+
eq.lhs ~ eq.rhs + brown
298+
end
299+
end
300+
newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys);
301+
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
302+
continuous_events = continuous_events(sys), discrete_events = discrete_events(sys),
303+
name = nameof(sys), description = description(sys), metadata = get_metadata(sys))
304+
@set newsys.parent = sys
305+
end
306+
266307
function __num_isdiag_noise(mat)
267308
for i in axes(mat, 1)
268309
nnz = 0

src/systems/systems.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function __structural_simplify(sys::JumpSystem, args...; kwargs...)
7272
return sys
7373
end
7474

75+
function __structural_simplify(sys::SDESystem, args...; kwargs...)
76+
return __structural_simplify(ODESystem(sys), args...; kwargs...)
77+
end
78+
7579
function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
7680
kwargs...)
7781
sys = expand_connections(sys)

src/variables.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,9 @@ $(SIGNATURES)
489489
Define one or more Brownian variables.
490490
"""
491491
macro brownian(xs...)
492-
all(x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$, xs) ||
492+
all(
493+
x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$ || Meta.isexpr(x, :$),
494+
xs) ||
493495
error("@brownian only takes scalar expressions!")
494496
Symbolics._parse_vars(:brownian,
495497
Real,

test/sdesystem.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,3 +809,62 @@ end
809809
prob = SDEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
810810
@test prob[z] 2.0
811811
end
812+
813+
@testset "SDESystem to ODESystem" begin
814+
@variables x(t) y(t) z(t)
815+
@testset "Scalar noise" begin
816+
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 3],
817+
t, [x, y, z], [], is_scalar_noise = true)
818+
odesys = ODESystem(sys)
819+
@test odesys isa ODESystem
820+
vs = ModelingToolkit.vars(equations(odesys))
821+
nbrownian = count(
822+
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
823+
@test nbrownian == 3
824+
for eq in equations(odesys)
825+
ModelingToolkit.isdiffeq(eq) || continue
826+
@test length(arguments(eq.rhs)) == 4
827+
end
828+
end
829+
830+
@testset "Non-scalar vector noise" begin
831+
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 0],
832+
t, [x, y, z], [], is_scalar_noise = false)
833+
odesys = ODESystem(sys)
834+
@test odesys isa ODESystem
835+
vs = ModelingToolkit.vars(equations(odesys))
836+
nbrownian = count(
837+
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
838+
@test nbrownian == 1
839+
for eq in equations(odesys)
840+
ModelingToolkit.isdiffeq(eq) || continue
841+
@test length(arguments(eq.rhs)) == 2
842+
end
843+
end
844+
845+
@testset "Matrix noise" begin
846+
noiseeqs = [x+y y+z z+x
847+
2y 2z 2x
848+
z+1 x+1 y+1]
849+
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, D(z) ~ z], noiseeqs, t, [x, y, z], [])
850+
odesys = ODESystem(sys)
851+
@test odesys isa ODESystem
852+
vs = ModelingToolkit.vars(equations(odesys))
853+
nbrownian = count(
854+
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
855+
@test nbrownian == 3
856+
for eq in equations(odesys)
857+
@test length(arguments(eq.rhs)) == 4
858+
end
859+
end
860+
end
861+
862+
@testset "`structural_simplify(::SDESystem)`" begin
863+
@variables x(t) y(t)
864+
@mtkbuild sys = SDESystem(
865+
[D(x) ~ x, y ~ 2x], [x, 0], t, [x, y], []; is_scalar_noise = true)
866+
@test sys isa SDESystem
867+
@test length(equations(sys)) == 1
868+
@test length(ModelingToolkit.get_noiseeqs(sys)) == 1
869+
@test length(observed(sys)) == 1
870+
end

0 commit comments

Comments
 (0)