Skip to content

Commit c8b98d3

Browse files
refactor: move flatten_equations to utils.jl
1 parent 9466ed7 commit c8b98d3

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/utils.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,3 +1326,31 @@ function _eq_unordered(a::AbstractArray, b::AbstractArray)
13261326
end
13271327
return true
13281328
end
1329+
1330+
"""
1331+
$(TYPEDSIGNATURES)
1332+
1333+
Given a list of equations where some may be array equations, flatten the array equations
1334+
without scalarizing occurrences of array variables and return the new list of equations.
1335+
"""
1336+
function flatten_equations(eqs::Vector{Equation})
1337+
mapreduce(vcat, eqs; init = Equation[]) do eq
1338+
islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs)
1339+
isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs)
1340+
if islhsarr || isrhsarr
1341+
islhsarr && isrhsarr ||
1342+
error("""
1343+
LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions \
1344+
or both scalar
1345+
""")
1346+
size(eq.lhs) == size(eq.rhs) ||
1347+
error("""
1348+
Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got \
1349+
$(size(eq.lhs)) and $(size(eq.rhs))
1350+
""")
1351+
return vec(collect(eq.lhs) .~ collect(eq.rhs))
1352+
else
1353+
eq
1354+
end
1355+
end
1356+
end

0 commit comments

Comments
 (0)