From c69d28e0d829e0f9a21a56ea8c2364c052f0a367 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 15 Oct 2024 10:59:51 -0400 Subject: [PATCH 1/6] add collect_vars! equation dispatch --- src/systems/diffeqs/odesystem.jl | 9 +++---- .../discrete_system/discrete_system.jl | 9 +++---- src/systems/nonlinear/nonlinearsystem.jl | 9 +++---- src/utils.jl | 24 +++++++++++++------ 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 68ea0b48e3..d1cd01ce7b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -319,8 +319,7 @@ function ODESystem(eqs, iv; kwargs...) compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)` for eq in eqs eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue) - collect_vars!(allunknowns, ps, eq.lhs, iv) - collect_vars!(allunknowns, ps, eq.rhs, iv) + collect_vars!(allunknowns, ps, eq, iv) if isdiffeq(eq) diffvar, _ = var_from_nested_derivative(eq.lhs) if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0) @@ -337,11 +336,9 @@ function ODESystem(eqs, iv; kwargs...) end for eq in get(kwargs, :parameter_dependencies, Equation[]) if eq isa Pair - collect_vars!(allunknowns, ps, eq[1], iv) - collect_vars!(allunknowns, ps, eq[2], iv) + collect_vars!(allunknowns, ps, eq, iv) else - collect_vars!(allunknowns, ps, eq.lhs, iv) - collect_vars!(allunknowns, ps, eq.rhs, iv) + collect_vars!(allunknowns, ps, eq, iv) end end for ssys in get(kwargs, :systems, ODESystem[]) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 7103cfca80..99f76a8ce9 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -175,8 +175,7 @@ function DiscreteSystem(eqs, iv; kwargs...) ps = OrderedSet() iv = value(iv) for eq in eqs - collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift) - collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift) + collect_vars!(allunknowns, ps, eq, iv; op = Shift) if iscall(eq.lhs) && operation(eq.lhs) isa Shift isequal(iv, operation(eq.lhs).t) || throw(ArgumentError("A DiscreteSystem can only have one independent variable.")) @@ -187,11 +186,9 @@ function DiscreteSystem(eqs, iv; kwargs...) end for eq in get(kwargs, :parameter_dependencies, Equation[]) if eq isa Pair - collect_vars!(allunknowns, ps, eq[1], iv) - collect_vars!(allunknowns, ps, eq[2], iv) + collect_vars!(allunknowns, ps, eq, iv) else - collect_vars!(allunknowns, ps, eq.lhs, iv) - collect_vars!(allunknowns, ps, eq.rhs, iv) + collect_vars!(allunknowns, ps, eq, iv) end end new_ps = OrderedSet() diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 46bf032d6f..c649b9b287 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -166,16 +166,13 @@ function NonlinearSystem(eqs; kwargs...) allunknowns = OrderedSet() ps = OrderedSet() for eq in eqs - collect_vars!(allunknowns, ps, eq.lhs, nothing) - collect_vars!(allunknowns, ps, eq.rhs, nothing) + collect_vars!(allunknowns, ps, eq, nothing) end for eq in get(kwargs, :parameter_dependencies, Equation[]) if eq isa Pair - collect_vars!(allunknowns, ps, eq[1], nothing) - collect_vars!(allunknowns, ps, eq[2], nothing) + collect_vars!(allunknowns, ps, eq, nothing) else - collect_vars!(allunknowns, ps, eq.lhs, nothing) - collect_vars!(allunknowns, ps, eq.rhs, nothing) + collect_vars!(allunknowns, ps, eq, nothing) end end new_ps = OrderedSet() diff --git a/src/utils.jl b/src/utils.jl index e8ed131d78..62445d5814 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -493,19 +493,16 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif if has_eqs(sys) for eq in get_eqs(sys) eq isa Equation || continue - eq.lhs isa Union{Symbolic, Number} || continue - collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) - collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) + (eq isa Equation && eq.lhs isa Union{Symbolic, Number}) || continue + collect_vars!(unknowns, parameters, eq, iv; depth, op) end end if has_parameter_dependencies(sys) for eq in get_parameter_dependencies(sys) if eq isa Pair - collect_vars!(unknowns, parameters, eq[1], iv; depth, op) - collect_vars!(unknowns, parameters, eq[2], iv; depth, op) + collect_vars!(unknowns, parameters, eq, iv; depth, op) else - collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) - collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) + collect_vars!(unknowns, parameters, eq, iv; depth, op) end end end @@ -529,6 +526,19 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different return nothing end +function collect_vars!(unknowns, parameters, eq::Equation, iv; + depth = 0, op = Differential) + collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) + collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) + return nothing +end + +function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differential) + collect_vars!(unknowns, parameters, p[1], iv; depth, op) + collect_vars!(unknowns, parameters, p[2], iv; depth, op) + return nothing +end + function collect_var!(unknowns, parameters, var, iv; depth = 0) isequal(var, iv) && return nothing check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing From dbe211d80c4008726329d15ed029c8f21d4ec45e Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 15 Oct 2024 11:07:17 -0400 Subject: [PATCH 2/6] add collect_vars equation dispatch --- src/utils.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 62445d5814..1eaa2fa8eb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -492,7 +492,7 @@ recursively searches through all subsystems of `sys`, increasing the depth if it function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential) if has_eqs(sys) for eq in get_eqs(sys) - eq isa Equation || continue + eqtype_supports_collect_vars(eq) || continue (eq isa Equation && eq.lhs isa Union{Symbolic, Number}) || continue collect_vars!(unknowns, parameters, eq, iv; depth, op) end @@ -526,6 +526,16 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Different return nothing end +""" + $(TYPEDSIGNATURES) + +Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`. Can +be dispatched by higher-level libraries to indicate support. +""" +eqtype_supports_collect_vars(eq) = false +eqtype_supports_collect_vars(eq::Equation) = true +eqtype_supports_collect_vars(eq::Pair) = true + function collect_vars!(unknowns, parameters, eq::Equation, iv; depth = 0, op = Differential) collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) From 1209b1e0fd0d9678520a5ee2c4a2b605ecd8179a Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 15 Oct 2024 11:14:37 -0400 Subject: [PATCH 3/6] comment tweak --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 1eaa2fa8eb..5eb411a1e5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -529,8 +529,8 @@ end """ $(TYPEDSIGNATURES) -Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`. Can -be dispatched by higher-level libraries to indicate support. +Indicate whether the given equation type (Equation, Pair, etc) supports `collect_vars!`. +Can be dispatched by higher-level libraries to indicate support. """ eqtype_supports_collect_vars(eq) = false eqtype_supports_collect_vars(eq::Equation) = true From 9cd326bc9b64654a2b7a39f73dedde1d259b896c Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 15 Oct 2024 11:19:44 -0400 Subject: [PATCH 4/6] format --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 5eb411a1e5..1072c37353 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -536,7 +536,7 @@ eqtype_supports_collect_vars(eq) = false eqtype_supports_collect_vars(eq::Equation) = true eqtype_supports_collect_vars(eq::Pair) = true -function collect_vars!(unknowns, parameters, eq::Equation, iv; +function collect_vars!(unknowns, parameters, eq::Equation, iv; depth = 0, op = Differential) collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) From 557e20d194af7970142e98951887bcbf711cbb8b Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 15 Oct 2024 11:24:11 -0400 Subject: [PATCH 5/6] fix check --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 1072c37353..67074246b9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -493,7 +493,9 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif if has_eqs(sys) for eq in get_eqs(sys) eqtype_supports_collect_vars(eq) || continue - (eq isa Equation && eq.lhs isa Union{Symbolic, Number}) || continue + if eq isa Equation + eq.lhs isa Union{Symbolic, Number} || continue + end collect_vars!(unknowns, parameters, eq, iv; depth, op) end end From a8c09306c5bb4a2198a6936d461d2b4cc157faaa Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 15 Oct 2024 12:15:25 -0400 Subject: [PATCH 6/6] format --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 67074246b9..c13d8a480a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -493,7 +493,7 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif if has_eqs(sys) for eq in get_eqs(sys) eqtype_supports_collect_vars(eq) || continue - if eq isa Equation + if eq isa Equation eq.lhs isa Union{Symbolic, Number} || continue end collect_vars!(unknowns, parameters, eq, iv; depth, op)