Skip to content

Commit 298727a

Browse files
committed
up
2 parents 68bb1b3 + 5c1cebf commit 298727a

File tree

3 files changed

+87
-67
lines changed

3 files changed

+87
-67
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -296,73 +296,9 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
296296
end
297297

298298
function ODESystem(eqs, iv; kwargs...)
299-
eqs = collect(eqs)
300-
# NOTE: this assumes that the order of algebraic equations doesn't matter
301-
diffvars = OrderedSet()
302-
allunknowns = OrderedSet()
303-
ps = OrderedSet()
304-
# reorder equations such that it is in the form of `diffeq, algeeq`
305-
diffeq = Equation[]
306-
algeeq = Equation[]
307-
# initial loop for finding `iv`
308-
if iv === nothing
309-
for eq in eqs
310-
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
311-
iv = iv_from_nested_derivative(eq.lhs)
312-
break
313-
end
314-
end
315-
end
316-
iv = value(iv)
317-
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
318-
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
319-
for eq in eqs
320-
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
321-
collect_vars!(allunknowns, ps, eq, iv)
322-
if isdiffeq(eq)
323-
diffvar, _ = var_from_nested_derivative(eq.lhs)
324-
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
325-
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
326-
throw(ArgumentError("An ODESystem can only have one independent variable."))
327-
diffvar in diffvars &&
328-
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
329-
push!(diffvars, diffvar)
330-
end
331-
!(symtype(diffvar) === Real || eltype(symtype(var)) === Real) && throw(ArgumentError("Differential variable $var has type $(symtype(var)). Differential variables should not be concretely typed."))
332-
333-
push!(diffeq, eq)
334-
else
335-
push!(algeeq, eq)
336-
end
337-
end
338-
for eq in get(kwargs, :parameter_dependencies, Equation[])
339-
collect_vars!(allunknowns, ps, eq, iv)
340-
end
341-
for ssys in get(kwargs, :systems, ODESystem[])
342-
collect_scoped_vars!(allunknowns, ps, ssys, iv)
343-
end
344-
for v in allunknowns
345-
isdelay(v, iv) || continue
346-
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
347-
end
348-
new_ps = OrderedSet()
349-
for p in ps
350-
if iscall(p) && operation(p) === getindex
351-
par = arguments(p)[begin]
352-
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
353-
all(par[i] in ps for i in eachindex(par))
354-
push!(new_ps, par)
355-
else
356-
push!(new_ps, p)
357-
end
358-
else
359-
push!(new_ps, p)
360-
end
361-
end
362-
algevars = setdiff(allunknowns, diffvars)
363-
# the orders here are very important!
364-
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
365-
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
299+
param_deps = get(kwargs, :parameter_dependencies, Equation[])
300+
eqs, dvs, ps = process_equations_DESystem(eqs, param_deps, iv)
301+
return ODESystem(eqs, iv, dvs, ps; kwargs...)
366302
end
367303

368304
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,21 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
273273
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
274274
end
275275

276+
function SDESystem(eqs, noiseeqs, iv; kwargs...)
277+
param_deps = get(kwargs, :parameter_dependencies, Equation[])
278+
eqs, dvs, ps = process_equations_DESystem(eqs, param_deps, iv)
279+
280+
# validate noise equations
281+
noisedvs = OrderedSet()
282+
noiseps = OrderedSet()
283+
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
284+
for dv in noisedvs
285+
var Set(dvs) || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
286+
end
287+
288+
return SDESystem(eqs, noiseeqs, iv, dvs, [ps; collect(noiseps)]; kwargs...)
289+
end
290+
276291
function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
277292
sys1 === sys2 && return true
278293
iv1 = get_iv(sys1)

src/utils.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,3 +1186,72 @@ function guesses_from_metadata!(guesses, vars)
11861186
end
11871187
end
11881188

1189+
"""
1190+
$(TYPEDSIGNATURES)
1191+
1192+
Find all the unknowns and parameters from the equations and parameter dependencies of an ODESystem or SDESystem. Return re-ordered equations, unknowns, and parameters.
1193+
"""
1194+
function process_equations_DESystem(eqs, param_deps, iv)
1195+
eqs = collect(eqs)
1196+
1197+
diffvars = OrderedSet()
1198+
allunknowns = OrderedSet()
1199+
ps = OrderedSet()
1200+
1201+
# NOTE: this assumes that the order of algebraic equations doesn't matter
1202+
# reorder equations such that it is in the form of `diffeq, algeeq`
1203+
diffeq = Equation[]
1204+
algeeq = Equation[]
1205+
# initial loop for finding `iv`
1206+
if iv === nothing
1207+
for eq in eqs
1208+
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
1209+
iv = iv_from_nested_derivative(eq.lhs)
1210+
break
1211+
end
1212+
end
1213+
end
1214+
iv = value(iv)
1215+
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
1216+
1217+
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
1218+
for eq in eqs
1219+
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
1220+
collect_vars!(allunknowns, ps, eq, iv)
1221+
if isdiffeq(eq)
1222+
diffvar, _ = var_from_nested_derivative(eq.lhs)
1223+
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
1224+
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
1225+
throw(ArgumentError("An ODESystem can only have one independent variable."))
1226+
diffvar in diffvars &&
1227+
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
1228+
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && throw(ArgumentError("Differential variable $var has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
1229+
push!(diffvars, diffvar)
1230+
end
1231+
push!(diffeq, eq)
1232+
else
1233+
push!(algeeq, eq)
1234+
end
1235+
end
1236+
for eq in param_deps
1237+
collect_vars!(allunknowns, ps, eq, iv)
1238+
end
1239+
1240+
new_ps = OrderedSet()
1241+
for p in ps
1242+
if iscall(p) && operation(p) === getindex
1243+
par = arguments(p)[begin]
1244+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
1245+
all(par[i] in ps for i in eachindex(par))
1246+
push!(new_ps, par)
1247+
else
1248+
push!(new_ps, p)
1249+
end
1250+
else
1251+
push!(new_ps, p)
1252+
end
1253+
end
1254+
algevars = setdiff(allunknowns, diffvars)
1255+
1256+
Equation[diffeq; algeeq; compressed_eqs], collect(Iterators.flatten((diffvars, algevars))), collect(new_ps)
1257+
end

0 commit comments

Comments
 (0)