Skip to content

Commit c69d28e

Browse files
committed
add collect_vars! equation dispatch
1 parent 1e41add commit c69d28e

File tree

4 files changed

+26
-25
lines changed

4 files changed

+26
-25
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,7 @@ function ODESystem(eqs, iv; kwargs...)
319319
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
320320
for eq in eqs
321321
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
322-
collect_vars!(allunknowns, ps, eq.lhs, iv)
323-
collect_vars!(allunknowns, ps, eq.rhs, iv)
322+
collect_vars!(allunknowns, ps, eq, iv)
324323
if isdiffeq(eq)
325324
diffvar, _ = var_from_nested_derivative(eq.lhs)
326325
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
@@ -337,11 +336,9 @@ function ODESystem(eqs, iv; kwargs...)
337336
end
338337
for eq in get(kwargs, :parameter_dependencies, Equation[])
339338
if eq isa Pair
340-
collect_vars!(allunknowns, ps, eq[1], iv)
341-
collect_vars!(allunknowns, ps, eq[2], iv)
339+
collect_vars!(allunknowns, ps, eq, iv)
342340
else
343-
collect_vars!(allunknowns, ps, eq.lhs, iv)
344-
collect_vars!(allunknowns, ps, eq.rhs, iv)
341+
collect_vars!(allunknowns, ps, eq, iv)
345342
end
346343
end
347344
for ssys in get(kwargs, :systems, ODESystem[])

src/systems/discrete_system/discrete_system.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,7 @@ function DiscreteSystem(eqs, iv; kwargs...)
175175
ps = OrderedSet()
176176
iv = value(iv)
177177
for eq in eqs
178-
collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift)
179-
collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift)
178+
collect_vars!(allunknowns, ps, eq, iv; op = Shift)
180179
if iscall(eq.lhs) && operation(eq.lhs) isa Shift
181180
isequal(iv, operation(eq.lhs).t) ||
182181
throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
@@ -187,11 +186,9 @@ function DiscreteSystem(eqs, iv; kwargs...)
187186
end
188187
for eq in get(kwargs, :parameter_dependencies, Equation[])
189188
if eq isa Pair
190-
collect_vars!(allunknowns, ps, eq[1], iv)
191-
collect_vars!(allunknowns, ps, eq[2], iv)
189+
collect_vars!(allunknowns, ps, eq, iv)
192190
else
193-
collect_vars!(allunknowns, ps, eq.lhs, iv)
194-
collect_vars!(allunknowns, ps, eq.rhs, iv)
191+
collect_vars!(allunknowns, ps, eq, iv)
195192
end
196193
end
197194
new_ps = OrderedSet()

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,13 @@ function NonlinearSystem(eqs; kwargs...)
166166
allunknowns = OrderedSet()
167167
ps = OrderedSet()
168168
for eq in eqs
169-
collect_vars!(allunknowns, ps, eq.lhs, nothing)
170-
collect_vars!(allunknowns, ps, eq.rhs, nothing)
169+
collect_vars!(allunknowns, ps, eq, nothing)
171170
end
172171
for eq in get(kwargs, :parameter_dependencies, Equation[])
173172
if eq isa Pair
174-
collect_vars!(allunknowns, ps, eq[1], nothing)
175-
collect_vars!(allunknowns, ps, eq[2], nothing)
173+
collect_vars!(allunknowns, ps, eq, nothing)
176174
else
177-
collect_vars!(allunknowns, ps, eq.lhs, nothing)
178-
collect_vars!(allunknowns, ps, eq.rhs, nothing)
175+
collect_vars!(allunknowns, ps, eq, nothing)
179176
end
180177
end
181178
new_ps = OrderedSet()

src/utils.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -493,19 +493,16 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif
493493
if has_eqs(sys)
494494
for eq in get_eqs(sys)
495495
eq isa Equation || continue
496-
eq.lhs isa Union{Symbolic, Number} || continue
497-
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
498-
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
496+
(eq isa Equation && eq.lhs isa Union{Symbolic, Number}) || continue
497+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
499498
end
500499
end
501500
if has_parameter_dependencies(sys)
502501
for eq in get_parameter_dependencies(sys)
503502
if eq isa Pair
504-
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
505-
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
503+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
506504
else
507-
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
508-
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
505+
collect_vars!(unknowns, parameters, eq, iv; depth, op)
509506
end
510507
end
511508
end
@@ -529,6 +526,19 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different
529526
return nothing
530527
end
531528

529+
function collect_vars!(unknowns, parameters, eq::Equation, iv;
530+
depth = 0, op = Differential)
531+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
532+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
533+
return nothing
534+
end
535+
536+
function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differential)
537+
collect_vars!(unknowns, parameters, p[1], iv; depth, op)
538+
collect_vars!(unknowns, parameters, p[2], iv; depth, op)
539+
return nothing
540+
end
541+
532542
function collect_var!(unknowns, parameters, var, iv; depth = 0)
533543
isequal(var, iv) && return nothing
534544
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing

0 commit comments

Comments
 (0)