Skip to content

Commit 288a4a0

Browse files
committed
up
1 parent 3329ce5 commit 288a4a0

File tree

3 files changed

+67
-29
lines changed

3 files changed

+67
-29
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,38 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
296296
end
297297

298298
function ODESystem(eqs, iv; kwargs...)
299-
param_deps = get(kwargs, :parameter_dependencies, Equation[])
300-
eqs, dvs, ps = process_equations_DESystem(eqs, param_deps, iv)
301-
return ODESystem(eqs, iv, dvs, ps; kwargs...)
299+
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
300+
301+
for eq in get(kwargs, :parameter_dependencies, Equation[])
302+
collect_vars!(allunknowns, ps, eq, iv)
303+
end
304+
305+
for ssys in get(kwargs, :systems, ODESystem[])
306+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
307+
end
308+
309+
for v in allunknowns
310+
isdelay(v, iv) || continue
311+
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
312+
end
313+
314+
new_ps = OrderedSet()
315+
for p in ps
316+
if iscall(p) && operation(p) === getindex
317+
par = arguments(p)[begin]
318+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
319+
all(par[i] in ps for i in eachindex(par))
320+
push!(new_ps, par)
321+
else
322+
push!(new_ps, p)
323+
end
324+
else
325+
push!(new_ps, p)
326+
end
327+
end
328+
algevars = setdiff(allunknowns, diffvars)
329+
330+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
302331
end
303332

304333
# NOTE: equality does not check cached Jacobian

src/systems/diffeqs/sdesystem.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,18 +274,46 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
274274
end
275275

276276
function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
277-
param_deps = get(kwargs, :parameter_dependencies, Equation[])
278-
eqs, dvs, ps = process_equations_DESystem(eqs, param_deps, iv)
277+
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
278+
279+
for eq in get(kwargs, :parameter_dependencies, Equation[])
280+
collect_vars!(allunknowns, ps, eq, iv)
281+
end
282+
283+
for ssys in get(kwargs, :systems, ODESystem[])
284+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
285+
end
286+
287+
for v in allunknowns
288+
isdelay(v, iv) || continue
289+
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
290+
end
291+
292+
new_ps = OrderedSet()
293+
for p in ps
294+
if iscall(p) && operation(p) === getindex
295+
par = arguments(p)[begin]
296+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
297+
all(par[i] in ps for i in eachindex(par))
298+
push!(new_ps, par)
299+
else
300+
push!(new_ps, p)
301+
end
302+
else
303+
push!(new_ps, p)
304+
end
305+
end
279306

280307
# validate noise equations
281308
noisedvs = OrderedSet()
282309
noiseps = OrderedSet()
283310
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
284311
for dv in noisedvs
285-
var Set(dvs) || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
312+
var allunknowns || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
286313
end
314+
algevars = setdiff(allunknowns, diffvars)
287315

288-
return SDESystem(eqs, noiseeqs, iv, dvs, [ps; collect(noiseps)]; kwargs...)
316+
return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)), [ps; collect(noiseps)]; kwargs...)
289317
end
290318

291319
SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...) = SDESystem([eq], noiseeqs, args...; kwargs...)

src/utils.jl

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,9 +1189,9 @@ end
11891189
"""
11901190
$(TYPEDSIGNATURES)
11911191
1192-
Find all the unknowns and parameters from the equations and parameter dependencies of an ODESystem or SDESystem. Return re-ordered equations, unknowns, and parameters.
1192+
Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
11931193
"""
1194-
function process_equations_DESystem(eqs, param_deps, iv)
1194+
function process_equations(eqs, iv)
11951195
eqs = collect(eqs)
11961196

11971197
diffvars = OrderedSet()
@@ -1233,25 +1233,6 @@ function process_equations_DESystem(eqs, param_deps, iv)
12331233
push!(algeeq, eq)
12341234
end
12351235
end
1236-
for eq in param_deps
1237-
collect_vars!(allunknowns, ps, eq, iv)
1238-
end
1239-
1240-
new_ps = OrderedSet()
1241-
for p in ps
1242-
if iscall(p) && operation(p) === getindex
1243-
par = arguments(p)[begin]
1244-
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
1245-
all(par[i] in ps for i in eachindex(par))
1246-
push!(new_ps, par)
1247-
else
1248-
push!(new_ps, p)
1249-
end
1250-
else
1251-
push!(new_ps, p)
1252-
end
1253-
end
1254-
algevars = setdiff(allunknowns, diffvars)
12551236

1256-
Equation[diffeq; algeeq; compressed_eqs], collect(Iterators.flatten((diffvars, algevars))), collect(new_ps)
1237+
diffvars, allunknowns, ps, Equation[diffeq; algeeq; compressed_eqs]
12571238
end

0 commit comments

Comments
 (0)