Skip to content

Commit 5c1cebf

Browse files
committed
add constructor
1 parent 1b7261a commit 5c1cebf

File tree

3 files changed

+87
-64
lines changed

3 files changed

+87
-64
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -296,71 +296,10 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
296296
end
297297

298298
function ODESystem(eqs, iv; kwargs...)
299-
eqs = collect(eqs)
300-
# NOTE: this assumes that the order of algebraic equations doesn't matter
301-
diffvars = OrderedSet()
302-
allunknowns = OrderedSet()
303-
ps = OrderedSet()
304-
# reorder equations such that it is in the form of `diffeq, algeeq`
305-
diffeq = Equation[]
306-
algeeq = Equation[]
307-
# initial loop for finding `iv`
308-
if iv === nothing
309-
for eq in eqs
310-
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
311-
iv = iv_from_nested_derivative(eq.lhs)
312-
break
313-
end
314-
end
315-
end
316-
iv = value(iv)
317-
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
318-
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
319-
for eq in eqs
320-
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
321-
collect_vars!(allunknowns, ps, eq, iv)
322-
if isdiffeq(eq)
323-
diffvar, _ = var_from_nested_derivative(eq.lhs)
324-
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
325-
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
326-
throw(ArgumentError("An ODESystem can only have one independent variable."))
327-
diffvar in diffvars &&
328-
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
329-
push!(diffvars, diffvar)
330-
end
331-
push!(diffeq, eq)
332-
else
333-
push!(algeeq, eq)
334-
end
335-
end
336-
for eq in get(kwargs, :parameter_dependencies, Equation[])
337-
collect_vars!(allunknowns, ps, eq, iv)
338-
end
339-
for ssys in get(kwargs, :systems, ODESystem[])
340-
collect_scoped_vars!(allunknowns, ps, ssys, iv)
341-
end
342-
for v in allunknowns
343-
isdelay(v, iv) || continue
344-
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
345-
end
346-
new_ps = OrderedSet()
347-
for p in ps
348-
if iscall(p) && operation(p) === getindex
349-
par = arguments(p)[begin]
350-
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
351-
all(par[i] in ps for i in eachindex(par))
352-
push!(new_ps, par)
353-
else
354-
push!(new_ps, p)
355-
end
356-
else
357-
push!(new_ps, p)
358-
end
359-
end
360-
algevars = setdiff(allunknowns, diffvars)
299+
param_deps = get(kwargs, :parameter_dependencies, Equation[])
300+
eqs, dvs, ps = process_equations_DESystem(eqs, param_deps, iv)
361301
# the orders here are very important!
362-
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
363-
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
302+
return ODESystem(eqs, iv, dvs, ps; kwargs...)
364303
end
365304

366305
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,21 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
273273
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
274274
end
275275

276+
function SDESystem(eqs, noiseeqs, iv; kwargs...)
277+
param_deps = get(kwargs, :parameter_dependencies, Equation[])
278+
eqs, dvs, ps = process_equations_DESystem(eqs, param_deps, iv)
279+
280+
# validate noise equations
281+
noisedvs = OrderedSet()
282+
noiseps = OrderedSet()
283+
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
284+
for dv in noisedvs
285+
var Set(dvs) || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
286+
end
287+
288+
return SDESystem(eqs, noiseeqs, iv, dvs, [ps; collect(noiseps)]; kwargs...)
289+
end
290+
276291
function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
277292
sys1 === sys2 && return true
278293
iv1 = get_iv(sys1)

src/utils.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,3 +1185,72 @@ function guesses_from_metadata!(guesses, vars)
11851185
guesses[vars[i]] = varguesses[i]
11861186
end
11871187
end
1188+
1189+
"""
1190+
$(TYPEDSIGNATURES)
1191+
1192+
Find all the unknowns and parameters from the equations and parameter dependencies of an ODESystem or SDESystem. Return re-ordered equations, unknowns, and parameters.
1193+
"""
1194+
function process_equations_DESystem(eqs, param_deps, iv)
1195+
eqs = collect(eqs)
1196+
1197+
diffvars = OrderedSet()
1198+
allunknowns = OrderedSet()
1199+
ps = OrderedSet()
1200+
1201+
# NOTE: this assumes that the order of algebraic equations doesn't matter
1202+
# reorder equations such that it is in the form of `diffeq, algeeq`
1203+
diffeq = Equation[]
1204+
algeeq = Equation[]
1205+
# initial loop for finding `iv`
1206+
if iv === nothing
1207+
for eq in eqs
1208+
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
1209+
iv = iv_from_nested_derivative(eq.lhs)
1210+
break
1211+
end
1212+
end
1213+
end
1214+
iv = value(iv)
1215+
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
1216+
1217+
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
1218+
for eq in eqs
1219+
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
1220+
collect_vars!(allunknowns, ps, eq, iv)
1221+
if isdiffeq(eq)
1222+
diffvar, _ = var_from_nested_derivative(eq.lhs)
1223+
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
1224+
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
1225+
throw(ArgumentError("An ODESystem can only have one independent variable."))
1226+
diffvar in diffvars &&
1227+
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
1228+
push!(diffvars, diffvar)
1229+
end
1230+
push!(diffeq, eq)
1231+
else
1232+
push!(algeeq, eq)
1233+
end
1234+
end
1235+
for eq in param_deps
1236+
collect_vars!(allunknowns, ps, eq, iv)
1237+
end
1238+
1239+
new_ps = OrderedSet()
1240+
for p in ps
1241+
if iscall(p) && operation(p) === getindex
1242+
par = arguments(p)[begin]
1243+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
1244+
all(par[i] in ps for i in eachindex(par))
1245+
push!(new_ps, par)
1246+
else
1247+
push!(new_ps, p)
1248+
end
1249+
else
1250+
push!(new_ps, p)
1251+
end
1252+
end
1253+
algevars = setdiff(allunknowns, diffvars)
1254+
1255+
Equation[diffeq; algeeq; compressed_eqs], collect(Iterators.flatten((diffvars, algevars))), collect(new_ps)
1256+
end

0 commit comments

Comments
 (0)