Skip to content

Commit 1f53f6a

Browse files
Merge pull request #3123 from isaacsas/dispatch_collect_vars
add trait and dispatch for collect_vars!
2 parents 28a5af3 + a8c0930 commit 1f53f6a

File tree

4 files changed

+39
-26
lines changed

4 files changed

+39
-26
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,7 @@ function ODESystem(eqs, iv; kwargs...)
319319
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
320320
for eq in eqs
321321
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
322-
collect_vars!(allunknowns, ps, eq.lhs, iv)
323-
collect_vars!(allunknowns, ps, eq.rhs, iv)
322+
collect_vars!(allunknowns, ps, eq, iv)
324323
if isdiffeq(eq)
325324
diffvar, _ = var_from_nested_derivative(eq.lhs)
326325
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
@@ -337,11 +336,9 @@ function ODESystem(eqs, iv; kwargs...)
337336
end
338337
for eq in get(kwargs, :parameter_dependencies, Equation[])
339338
if eq isa Pair
340-
collect_vars!(allunknowns, ps, eq[1], iv)
341-
collect_vars!(allunknowns, ps, eq[2], iv)
339+
collect_vars!(allunknowns, ps, eq, iv)
342340
else
343-
collect_vars!(allunknowns, ps, eq.lhs, iv)
344-
collect_vars!(allunknowns, ps, eq.rhs, iv)
341+
collect_vars!(allunknowns, ps, eq, iv)
345342
end
346343
end
347344
for ssys in get(kwargs, :systems, ODESystem[])

src/systems/discrete_system/discrete_system.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,7 @@ function DiscreteSystem(eqs, iv; kwargs...)
175175
ps = OrderedSet()
176176
iv = value(iv)
177177
for eq in eqs
178-
collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift)
179-
collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift)
178+
collect_vars!(allunknowns, ps, eq, iv; op = Shift)
180179
if iscall(eq.lhs) && operation(eq.lhs) isa Shift
181180
isequal(iv, operation(eq.lhs).t) ||
182181
throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
@@ -187,11 +186,9 @@ function DiscreteSystem(eqs, iv; kwargs...)
187186
end
188187
for eq in get(kwargs, :parameter_dependencies, Equation[])
189188
if eq isa Pair
190-
collect_vars!(allunknowns, ps, eq[1], iv)
191-
collect_vars!(allunknowns, ps, eq[2], iv)
189+
collect_vars!(allunknowns, ps, eq, iv)
192190
else
193-
collect_vars!(allunknowns, ps, eq.lhs, iv)
194-
collect_vars!(allunknowns, ps, eq.rhs, iv)
191+
collect_vars!(allunknowns, ps, eq, iv)
195192
end
196193
end
197194
new_ps = OrderedSet()

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,13 @@ function NonlinearSystem(eqs; kwargs...)
166166
allunknowns = OrderedSet()
167167
ps = OrderedSet()
168168
for eq in eqs
169-
collect_vars!(allunknowns, ps, eq.lhs, nothing)
170-
collect_vars!(allunknowns, ps, eq.rhs, nothing)
169+
collect_vars!(allunknowns, ps, eq, nothing)
171170
end
172171
for eq in get(kwargs, :parameter_dependencies, Equation[])
173172
if eq isa Pair
174-
collect_vars!(allunknowns, ps, eq[1], nothing)
175-
collect_vars!(allunknowns, ps, eq[2], nothing)
173+
collect_vars!(allunknowns, ps, eq, nothing)
176174
else
177-
collect_vars!(allunknowns, ps, eq.lhs, nothing)
178-
collect_vars!(allunknowns, ps, eq.rhs, nothing)
175+
collect_vars!(allunknowns, ps, eq, nothing)
179176
end
180177
end
181178
new_ps = OrderedSet()

src/utils.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -492,20 +492,19 @@ recursively searches through all subsystems of `sys`, increasing the depth if it
492492
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
493493
if has_eqs(sys)
494494
for eq in get_eqs(sys)
495-
eq isa Equation || continue
496-
eq.lhs isa Union{Symbolic, Number} || continue
497-
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
498-
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
495+
eqtype_supports_collect_vars(eq) || continue
496+
if eq isa Equation
497+
eq.lhs isa Union{Symbolic, Number} || continue
498+
end
499+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
499500
end
500501
end
501502
if has_parameter_dependencies(sys)
502503
for eq in get_parameter_dependencies(sys)
503504
if eq isa Pair
504-
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
505-
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
505+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
506506
else
507-
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
508-
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
507+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
509508
end
510509
end
511510
end
@@ -529,6 +528,29 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different
529528
return nothing
530529
end
531530

531+
"""
532+
$(TYPEDSIGNATURES)
533+
534+
Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`.
535+
Can be dispatched by higher-level libraries to indicate support.
536+
"""
537+
eqtype_supports_collect_vars(eq) = false
538+
eqtype_supports_collect_vars(eq::Equation) = true
539+
eqtype_supports_collect_vars(eq::Pair) = true
540+
541+
function collect_vars!(unknowns, parameters, eq::Equation, iv;
542+
depth = 0, op = Differential)
543+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
544+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
545+
return nothing
546+
end
547+
548+
function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differential)
549+
collect_vars!(unknowns, parameters, p[1], iv; depth, op)
550+
collect_vars!(unknowns, parameters, p[2], iv; depth, op)
551+
return nothing
552+
end
553+
532554
function collect_var!(unknowns, parameters, var, iv; depth = 0)
533555
isequal(var, iv) && return nothing
534556
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing

0 commit comments

Comments
 (0)