Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ function ODESystem(eqs, iv; kwargs...)
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
for eq in eqs
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
collect_vars!(allunknowns, ps, eq.lhs, iv)
collect_vars!(allunknowns, ps, eq.rhs, iv)
collect_vars!(allunknowns, ps, eq, iv)
if isdiffeq(eq)
diffvar, _ = var_from_nested_derivative(eq.lhs)
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
Expand All @@ -337,11 +336,9 @@ function ODESystem(eqs, iv; kwargs...)
end
for eq in get(kwargs, :parameter_dependencies, Equation[])
if eq isa Pair
collect_vars!(allunknowns, ps, eq[1], iv)
collect_vars!(allunknowns, ps, eq[2], iv)
collect_vars!(allunknowns, ps, eq, iv)
else
collect_vars!(allunknowns, ps, eq.lhs, iv)
collect_vars!(allunknowns, ps, eq.rhs, iv)
collect_vars!(allunknowns, ps, eq, iv)
end
end
for ssys in get(kwargs, :systems, ODESystem[])
Expand Down
9 changes: 3 additions & 6 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ function DiscreteSystem(eqs, iv; kwargs...)
ps = OrderedSet()
iv = value(iv)
for eq in eqs
collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift)
collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift)
collect_vars!(allunknowns, ps, eq, iv; op = Shift)
if iscall(eq.lhs) && operation(eq.lhs) isa Shift
isequal(iv, operation(eq.lhs).t) ||
throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
Expand All @@ -187,11 +186,9 @@ function DiscreteSystem(eqs, iv; kwargs...)
end
for eq in get(kwargs, :parameter_dependencies, Equation[])
if eq isa Pair
collect_vars!(allunknowns, ps, eq[1], iv)
collect_vars!(allunknowns, ps, eq[2], iv)
collect_vars!(allunknowns, ps, eq, iv)
else
collect_vars!(allunknowns, ps, eq.lhs, iv)
collect_vars!(allunknowns, ps, eq.rhs, iv)
collect_vars!(allunknowns, ps, eq, iv)
end
end
new_ps = OrderedSet()
Expand Down
9 changes: 3 additions & 6 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,13 @@ function NonlinearSystem(eqs; kwargs...)
allunknowns = OrderedSet()
ps = OrderedSet()
for eq in eqs
collect_vars!(allunknowns, ps, eq.lhs, nothing)
collect_vars!(allunknowns, ps, eq.rhs, nothing)
collect_vars!(allunknowns, ps, eq, nothing)
end
for eq in get(kwargs, :parameter_dependencies, Equation[])
if eq isa Pair
collect_vars!(allunknowns, ps, eq[1], nothing)
collect_vars!(allunknowns, ps, eq[2], nothing)
collect_vars!(allunknowns, ps, eq, nothing)
else
collect_vars!(allunknowns, ps, eq.lhs, nothing)
collect_vars!(allunknowns, ps, eq.rhs, nothing)
collect_vars!(allunknowns, ps, eq, nothing)
end
end
new_ps = OrderedSet()
Expand Down
38 changes: 30 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,20 +492,19 @@ recursively searches through all subsystems of `sys`, increasing the depth if it
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
if has_eqs(sys)
for eq in get_eqs(sys)
eq isa Equation || continue
eq.lhs isa Union{Symbolic, Number} || continue
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
eqtype_supports_collect_vars(eq) || continue
if eq isa Equation
eq.lhs isa Union{Symbolic, Number} || continue
end
collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
end
if has_parameter_dependencies(sys)
for eq in get_parameter_dependencies(sys)
if eq isa Pair
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
collect_vars!(unknowns, parameters, eq, iv; depth, op)
else
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
end
end
Expand All @@ -529,6 +528,29 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different
return nothing
end

"""
$(TYPEDSIGNATURES)

Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`.
Can be dispatched by higher-level libraries to indicate support.
"""
eqtype_supports_collect_vars(eq) = false
eqtype_supports_collect_vars(eq::Equation) = true
eqtype_supports_collect_vars(eq::Pair) = true

function collect_vars!(unknowns, parameters, eq::Equation, iv;
depth = 0, op = Differential)
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
return nothing
end

function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differential)
collect_vars!(unknowns, parameters, p[1], iv; depth, op)
collect_vars!(unknowns, parameters, p[2], iv; depth, op)
return nothing
end

function collect_var!(unknowns, parameters, var, iv; depth = 0)
isequal(var, iv) && return nothing
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
Expand Down
Loading