Skip to content

Commit a2acbe5

Browse files
feat: add ability to convert SDESystem to equivalent ODESystem
1 parent 6d6383a commit a2acbe5

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,47 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
263263
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
264264
end
265265

266+
"""
267+
function ODESystem(sys::SDESystem)
268+
269+
Convert an `SDESystem` to the equivalent `ODESystem` using `@brownian` variables instead
270+
of noise equations. The returned system will not be `iscomplete` and will not have an
271+
index cache, regardless of `iscomplete(sys)`.
272+
"""
273+
function ODESystem(sys::SDESystem)
274+
neqs = get_noiseeqs(sys)
275+
eqs = equations(sys)
276+
is_scalar_noise = get_is_scalar_noise(sys)
277+
nbrownian = if is_scalar_noise
278+
length(neqs)
279+
else
280+
size(neqs, 2)
281+
end
282+
brownvars = map(1:nbrownian) do i
283+
name = gensym(Symbol(:brown_, i))
284+
only(@brownian $name)
285+
end
286+
if is_scalar_noise
287+
brownterms = reduce(+, neqs .* brownvars; init = 0)
288+
neweqs = map(eqs) do eq
289+
eq.lhs ~ eq.rhs + brownterms
290+
end
291+
else
292+
if neqs isa AbstractVector
293+
neqs = reshape(neqs, (length(neqs), 1))
294+
end
295+
brownterms = neqs * brownvars
296+
neweqs = map(eqs, brownterms) do eq, brown
297+
eq.lhs ~ eq.rhs + brown
298+
end
299+
end
300+
newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys);
301+
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
302+
continuous_events = continuous_events(sys), discrete_events = discrete_events(sys),
303+
name = nameof(sys), description = description(sys), metadata = get_metadata(sys))
304+
@set newsys.parent = sys
305+
end
306+
266307
function __num_isdiag_noise(mat)
267308
for i in axes(mat, 1)
268309
nnz = 0

test/sdesystem.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,3 +809,52 @@ end
809809
prob = SDEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
810810
@test prob[z] 2.0
811811
end
812+
813+
@testset "SDESystem to ODESystem" begin
814+
@variables x(t) y(t) z(t)
815+
@testset "Scalar noise" begin
816+
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 3],
817+
t, [x, y, z], [], is_scalar_noise = true)
818+
odesys = ODESystem(sys)
819+
@test odesys isa ODESystem
820+
vs = ModelingToolkit.vars(equations(odesys))
821+
nbrownian = count(
822+
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
823+
@test nbrownian == 3
824+
for eq in equations(odesys)
825+
ModelingToolkit.isdiffeq(eq) || continue
826+
@test length(arguments(eq.rhs)) == 4
827+
end
828+
end
829+
830+
@testset "Non-scalar vector noise" begin
831+
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 0],
832+
t, [x, y, z], [], is_scalar_noise = false)
833+
odesys = ODESystem(sys)
834+
@test odesys isa ODESystem
835+
vs = ModelingToolkit.vars(equations(odesys))
836+
nbrownian = count(
837+
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
838+
@test nbrownian == 1
839+
for eq in equations(odesys)
840+
ModelingToolkit.isdiffeq(eq) || continue
841+
@test length(arguments(eq.rhs)) == 2
842+
end
843+
end
844+
845+
@testset "Matrix noise" begin
846+
noiseeqs = [x+y y+z z+x
847+
2y 2z 2x
848+
z+1 x+1 y+1]
849+
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, D(z) ~ z], noiseeqs, t, [x, y, z], [])
850+
odesys = ODESystem(sys)
851+
@test odesys isa ODESystem
852+
vs = ModelingToolkit.vars(equations(odesys))
853+
nbrownian = count(
854+
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
855+
@test nbrownian == 3
856+
for eq in equations(odesys)
857+
@test length(arguments(eq.rhs)) == 4
858+
end
859+
end
860+
end

0 commit comments

Comments
 (0)