Skip to content

Commit 0e9d884

Browse files
authored
Merge pull request #2051 from SciML/myb/sde
Add `System` that handles ODESystem and SDESystem
2 parents 2006fa4 + 306428f commit 0e9d884

File tree

8 files changed

+156
-41
lines changed

8 files changed

+156
-41
lines changed

src/ModelingToolkit.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ include("discretedomain.jl")
147147
include("systems/systemstructure.jl")
148148
using .SystemStructures
149149
include("systems/clock_inference.jl")
150+
include("systems/systems.jl")
150151

151152
include("debugging.jl")
152153
include("systems/alias_elimination.jl")
@@ -162,8 +163,9 @@ end
162163

163164
export AbstractTimeDependentSystem, AbstractTimeIndependentSystem,
164165
AbstractMultivariateSystem
166+
165167
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr, convert_system,
166-
add_accumulations
168+
add_accumulations, System
167169
export DAEFunctionExpr, DAEProblemExpr
168170
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
169171
export SystemStructure
@@ -212,7 +214,7 @@ export simplify, substitute
212214
export build_function
213215
export modelingtoolkitize
214216

215-
export @variables, @parameters, @constants
217+
export @variables, @parameters, @constants, @brownian
216218
export @named, @nonamespace, @namespace, extend, compose, complete
217219
export debug_system
218220

src/parameters.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
import SymbolicUtils: symtype, term, hasmetadata, issym
2-
struct MTKParameterCtx end
2+
@enum VariableType VARIABLE PARAMETER BROWNIAN
3+
struct MTKVariableTypeCtx end
4+
5+
getvariabletype(x, def = VARIABLE) = getmetadata(unwrap(x), MTKVariableTypeCtx, def)
36

47
function isparameter(x)
58
x = unwrap(x)
69

7-
if x isa Symbolic && (isp = getmetadata(x, MTKParameterCtx, nothing)) !== nothing
8-
return isp
10+
if x isa Symbolic && (varT = getvariabletype(x, nothing)) !== nothing
11+
return varT === PARAMETER
912
#TODO: Delete this branch
1013
elseif x isa Symbolic && Symbolics.getparent(x, false) !== false
1114
p = Symbolics.getparent(x)
1215
isparameter(p) ||
1316
(hasmetadata(p, Symbolics.VariableSource) &&
1417
getmetadata(p, Symbolics.VariableSource)[1] == :parameters)
1518
elseif istree(x) && operation(x) isa Symbolic
16-
getmetadata(x, MTKParameterCtx, false) ||
17-
isparameter(operation(x))
19+
varT === PARAMETER || isparameter(operation(x))
1820
elseif istree(x) && operation(x) == (getindex)
1921
isparameter(arguments(x)[1])
2022
elseif x isa Symbolic
21-
getmetadata(x, MTKParameterCtx, false)
23+
varT === PARAMETER
2224
else
2325
false
2426
end
@@ -35,7 +37,7 @@ function toparam(s)
3537
elseif s isa AbstractArray
3638
map(toparam, s)
3739
else
38-
setmetadata(s, MTKParameterCtx, true)
40+
setmetadata(s, MTKVariableTypeCtx, PARAMETER)
3941
end
4042
end
4143
toparam(s::Num) = wrap(toparam(value(s)))
@@ -45,13 +47,13 @@ toparam(s::Num) = wrap(toparam(value(s)))
4547
4648
Maps the variable to a state.
4749
"""
48-
tovar(s::Symbolic) = setmetadata(s, MTKParameterCtx, false)
50+
tovar(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, VARIABLE)
4951
tovar(s::Num) = Num(tovar(value(s)))
5052

5153
"""
5254
$(SIGNATURES)
5355
54-
Define one or more known variables.
56+
Define one or more known parameters.
5557
"""
5658
macro parameters(xs...)
5759
Symbolics._parse_vars(:parameters,

src/systems/abstractsystem.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,27 +1136,6 @@ function debug_system(sys::AbstractSystem)
11361136
return sys
11371137
end
11381138

1139-
"""
1140-
$(SIGNATURES)
1141-
1142-
Structurally simplify algebraic equations in a system and compute the
1143-
topological sort of the observed equations. When `simplify=true`, the `simplify`
1144-
function will be applied during the tearing process. It also takes kwargs
1145-
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
1146-
types during tearing.
1147-
1148-
The optional argument `io` may take a tuple `(inputs, outputs)`.
1149-
This will convert all `inputs` to parameters and allow them to be unconnected, i.e.,
1150-
simplification will allow models where `n_states = n_equations - n_inputs`.
1151-
"""
1152-
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
1153-
kwargs...)
1154-
sys = expand_connections(sys)
1155-
sys isa DiscreteSystem && return sys
1156-
state = TearingState(sys)
1157-
structural_simplify!(state, io; simplify, kwargs...)
1158-
end
1159-
11601139
function eliminate_constants(sys::AbstractSystem)
11611140
if has_eqs(sys)
11621141
eqs = get_eqs(sys)

src/systems/diffeqs/sdesystem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ struct SDESystem <: AbstractODESystem
134134
end
135135
end
136136

137-
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
137+
function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dvs, ps;
138138
controls = Num[],
139139
observed = Num[],
140140
systems = SDESystem[],
@@ -190,6 +190,19 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
190190
SDESystem(equations(sys), neqs, get_iv(sys), states(sys), parameters(sys); kwargs...)
191191
end
192192

193+
function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
194+
sys1 === sys2 && return true
195+
iv1 = get_iv(sys1)
196+
iv2 = get_iv(sys2)
197+
isequal(iv1, iv2) &&
198+
isequal(nameof(sys1), nameof(sys2)) &&
199+
isequal(get_eqs(sys1), get_eqs(sys2)) &&
200+
isequal(get_noiseeqs(sys1), get_noiseeqs(sys2)) &&
201+
_eq_unordered(get_states(sys1), get_states(sys2)) &&
202+
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
203+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
204+
end
205+
193206
function generate_diffusion_function(sys::SDESystem, dvs = states(sys),
194207
ps = parameters(sys); kwargs...)
195208
return build_function(get_noiseeqs(sys),

src/systems/systems.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
function System(eqs::AbstractVector{<:Equation}, iv = nothing, args...; name = nothing,
2+
kw...)
3+
ODESystem(eqs, iv, args...; name, checks = false)
4+
end
5+
6+
"""
7+
$(SIGNATURES)
8+
9+
Structurally simplify algebraic equations in a system and compute the
10+
topological sort of the observed equations. When `simplify=true`, the `simplify`
11+
function will be applied during the tearing process. It also takes kwargs
12+
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
13+
types during tearing.
14+
15+
The optional argument `io` may take a tuple `(inputs, outputs)`.
16+
This will convert all `inputs` to parameters and allow them to be unconnected, i.e.,
17+
simplification will allow models where `n_states = n_equations - n_inputs`.
18+
"""
19+
function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
20+
kwargs...)
21+
sys = expand_connections(sys)
22+
sys isa DiscreteSystem && return sys
23+
state = TearingState(sys)
24+
25+
@unpack structure, fullvars = state
26+
@unpack graph, var_to_diff, var_types = structure
27+
eqs = equations(state)
28+
brown_vars = Int[]
29+
new_idxs = zeros(Int, length(var_types))
30+
idx = 0
31+
for (i, vt) in enumerate(var_types)
32+
if vt === BROWNIAN
33+
push!(brown_vars, i)
34+
else
35+
new_idxs[i] = (idx += 1)
36+
end
37+
end
38+
if isempty(brown_vars)
39+
return structural_simplify!(state, io; simplify, kwargs...)
40+
else
41+
Is = Int[]
42+
Js = Int[]
43+
vals = Num[]
44+
new_eqs = copy(eqs)
45+
dvar2eq = Dict{Any, Int}()
46+
for (v, dv) in enumerate(var_to_diff)
47+
dv === nothing && continue
48+
deqs = 𝑑neighbors(graph, dv)
49+
if length(deqs) != 1
50+
error("$(eqs[deqs]) is not handled.")
51+
end
52+
dvar2eq[fullvars[dv]] = only(deqs)
53+
end
54+
for (j, bj) in enumerate(brown_vars), i in 𝑑neighbors(graph, bj)
55+
push!(Is, i)
56+
push!(Js, j)
57+
eq = new_eqs[i]
58+
brown = fullvars[bj]
59+
(coeff, residual, islinear) = Symbolics.linear_expansion(eq, brown)
60+
islinear || error("$brown isn't linear in $eq")
61+
new_eqs[i] = 0 ~ residual
62+
push!(vals, coeff)
63+
end
64+
g = Matrix(sparse(Is, Js, vals))
65+
sys = state.sys
66+
@set! sys.eqs = new_eqs
67+
@set! sys.states = [v
68+
for (i, v) in enumerate(fullvars)
69+
if !iszero(new_idxs[i]) && invview(var_to_diff)[i] === nothing]
70+
# TODO: IO is not handled.
71+
ode_sys = structural_simplify(sys, io; simplify, kwargs...)
72+
eqs = equations(ode_sys)
73+
sorted_g_rows = zeros(Num, length(eqs), size(g, 2))
74+
for (i, eq) in enumerate(eqs)
75+
dvar = eq.lhs
76+
# differential equations always precede algebraic equations
77+
_iszero(dvar) && break
78+
g_row = get(dvar2eq, dvar, 0)
79+
iszero(g_row) && error("$dvar isn't handled.")
80+
g_row > size(g, 1) && continue
81+
@views copyto!(sorted_g_rows[i, :], g[g_row, :])
82+
end
83+
84+
return SDESystem(full_equations(ode_sys), sorted_g_rows,
85+
get_iv(ode_sys), states(ode_sys), parameters(ode_sys);
86+
name = nameof(ode_sys))
87+
end
88+
end

src/systems/systemstructure.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
99
value, InvalidSystemException, isdifferential, _iszero,
1010
isparameter, isconstant,
1111
independent_variables, SparseMatrixCLIL, AbstractSystem,
12-
equations, isirreducible, input_timedomain, TimeDomain
12+
equations, isirreducible, input_timedomain, TimeDomain,
13+
VariableType, getvariabletype
1314
using ..BipartiteGraphs
1415
import ..BipartiteGraphs: invview, complete
1516
using Graphs
@@ -31,8 +32,6 @@ export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discret
3132
export dervars_range, diffvars_range, algvars_range
3233
export DiffGraph, complete!
3334

34-
@enum VariableType::Int8 DIFFERENTIAL_VARIABLE ALGEBRAIC_VARIABLE DERIVATIVE_VARIABLE
35-
3635
struct DiffGraph <: Graphs.AbstractGraph{Int}
3736
primal_to_diff::Vector{Union{Int, Nothing}}
3837
diff_to_primal::Union{Nothing, Vector{Union{Int, Nothing}}}
@@ -149,13 +148,15 @@ Base.@kwdef mutable struct SystemStructure
149148
# or as `torn` to assert that tearing has run.
150149
graph::BipartiteGraph{Int, Nothing}
151150
solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing}
151+
var_types::Union{Vector{VariableType}, Nothing}
152152
only_discrete::Bool
153153
end
154154

155155
function Base.copy(structure::SystemStructure)
156+
var_types = structure.var_types === nothing ? nothing : copy(structure.var_types)
156157
SystemStructure(copy(structure.var_to_diff), copy(structure.eq_to_diff),
157158
copy(structure.graph), copy(structure.solvable_graph),
158-
structure.only_discrete)
159+
var_types, structure.only_discrete)
159160
end
160161

161162
is_only_discrete(s::SystemStructure) = s.only_discrete
@@ -230,9 +231,11 @@ function TearingState(sys; quick_cancel = false, check = true)
230231
symbolic_incidence = []
231232
fullvars = []
232233
var_counter = Ref(0)
233-
addvar! = let fullvars = fullvars, var_counter = var_counter
234+
var_types = VariableType[]
235+
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
234236
var -> begin get!(var2idx, var) do
235237
push!(fullvars, var)
238+
push!(var_types, getvariabletype(var))
236239
var_counter[] += 1
237240
end end
238241
end
@@ -331,7 +334,10 @@ function TearingState(sys; quick_cancel = false, check = true)
331334
push!(sorted_fullvars, v)
332335
end
333336
end
334-
fullvars = collect(sorted_fullvars)
337+
new_fullvars = collect(sorted_fullvars)
338+
sortperm = indexin(new_fullvars, fullvars)
339+
fullvars = new_fullvars
340+
var_types = var_types[sortperm]
335341
var2idx = Dict(fullvars .=> eachindex(fullvars))
336342
dervaridxs = 1:length(dervaridxs)
337343

@@ -358,7 +364,8 @@ function TearingState(sys; quick_cancel = false, check = true)
358364

359365
return TearingState(sys, fullvars,
360366
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
361-
complete(graph), nothing, false), Any[])
367+
complete(graph), nothing, var_types, false),
368+
Any[])
362369
end
363370

364371
function lower_order_var(dervar)

src/variables.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,27 @@ function isintegervar(x)
390390
p === nothing || (x = p)
391391
return Symbolics.getmetadata(x, VariableInteger, false)
392392
end
393+
394+
## Brownian
395+
"""
396+
tobrownian(s::Sym)
397+
398+
Maps the brownianiable to a state.
399+
"""
400+
tobrownian(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN)
401+
tobrownian(s::Num) = Num(tobrownian(value(s)))
402+
isbrownian(s) = getvariabletype(s) === BROWNIAN
403+
404+
"""
405+
$(SIGNATURES)
406+
407+
Define one or more Brownian variables.
408+
"""
409+
macro brownian(xs...)
410+
all(x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$, xs) ||
411+
error("@brownian only takes scalar expressions!")
412+
Symbolics._parse_vars(:brownian,
413+
Real,
414+
xs,
415+
tobrownian) |> esc
416+
end

test/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ fdif!(du, u0, p, t)
460460
]
461461
sys1 = SDESystem(eqs_short, noiseeqs, t, [x, y, z], [σ, ρ, β], name = :sys1)
462462
sys2 = SDESystem(eqs_short, noiseeqs, t, [x, y, z], [σ, ρ, β], name = :sys1)
463-
@test_throws ArgumentError SDESystem([sys2.y ~ sys1.z], t, [], [], [],
463+
@test_throws ArgumentError SDESystem([sys2.y ~ sys1.z], [], t, [], [],
464464
systems = [sys1, sys2], name = :foo)
465465
end
466466

0 commit comments

Comments
 (0)