@@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false
239239Base. isequal (:: Symbolic , :: Missing ) = false
240240Base. isequal (:: Missing , :: Symbolic ) = false
241241Base. isequal (:: Symbolic , :: Symbolic ) = false
242- coeff_isequal (a, b) = isequal (a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a== b))
243- function _allarequal (xs, ys):: Bool
242+ coeff_isequal (a, b; comparator = isequal) = comparator (a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a== b))
243+ function _allarequal (xs, ys; comparator = isequal ):: Bool
244244 N = length (xs)
245245 length (ys) == N || return false
246246 for n = 1 : N
247- isequal (xs[n], ys[n]) || return false
247+ comparator (xs[n], ys[n]) || return false
248248 end
249249 return true
250250end
@@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
258258 T === S || return false
259259 return _isequal (a, b, E):: Bool
260260end
261- function _isequal (a, b, E)
261+ function _isequal (a, b, E; comparator = isequal )
262262 if E === SYM
263263 nameof (a) === nameof (b)
264264 elseif E === ADD || E === MUL
265- coeff_isequal (a. coeff, b. coeff) && isequal (a. dict, b. dict)
265+ coeff_isequal (a. coeff, b. coeff; comparator ) && comparator (a. dict, b. dict)
266266 elseif E === DIV
267- isequal (a. num, b. num) && isequal (a. den, b. den)
267+ comparator (a. num, b. num) && comparator (a. den, b. den)
268268 elseif E === POW
269- isequal (a. exp, b. exp) && isequal (a. base, b. base)
269+ comparator (a. exp, b. exp) && comparator (a. base, b. base)
270270 elseif E === TERM
271271 a1 = arguments (a)
272272 a2 = arguments (b)
273- isequal (operation (a), operation (b)) && _allarequal (a1, a2)
273+ comparator (operation (a), operation (b)) && _allarequal (a1, a2; comparator )
274274 else
275275 error_on_type ()
276276 end
@@ -292,8 +292,14 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an
292292downstream packages like `ModelingToolkit.jl`, hence the need for this separate
293293function.
294294"""
295- function isequal_with_metadata (a:: BasicSymbolic , b:: BasicSymbolic ):: Bool
296- isequal (a, b) && isequal_with_metadata (metadata (a), metadata (b))
295+ function isequal_with_metadata (a:: BasicSymbolic{T} , b:: BasicSymbolic{S} ):: Bool where {T, S}
296+ a === b && return true
297+
298+ E = exprtype (a)
299+ E === exprtype (b) || return false
300+
301+ T === S || return false
302+ _isequal (a, b, E; comparator = isequal_with_metadata):: Bool && isequal_with_metadata (metadata (a), metadata (b)) || return false
297303end
298304
299305"""
@@ -303,9 +309,9 @@ Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively
303309`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal
304310metadata.
305311"""
306- function isequal_with_metadata (a:: Union{AbstractDict, NamedTuple} , b:: Union{AbstractDict, NamedTuple} )
312+ function isequal_with_metadata (a:: NamedTuple , b:: NamedTuple )
313+ a === b && return true
307314 typeof (a) == typeof (b) || return false
308- length (a) == length (b) || return false
309315
310316 for (k, v) in pairs (a)
311317 haskey (b, k) || return false
@@ -320,6 +326,36 @@ function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{Abst
320326 return true
321327end
322328
329+ function isequal_with_metadata (a:: AbstractDict , b:: AbstractDict )
330+ a === b && return true
331+ typeof (a) == typeof (b) || return false
332+ length (a) == length (b) || return false
333+
334+ akeys = collect (keys (a))
335+ avisited = falses (length (akeys))
336+ bkeys = collect (keys (b))
337+ bvisited = falses (length (bkeys))
338+
339+ for k in akeys
340+ idx = findfirst (eachindex (bkeys)) do i
341+ ! bvisited[i] && isequal_with_metadata (k, bkeys[i])
342+ end
343+ idx === nothing && return false
344+ bvisited[idx] = true
345+ isequal_with_metadata (a[k], b[bkeys[idx]]) || return false
346+ end
347+ for (j, k) in enumerate (bkeys)
348+ bvisited[j] && continue
349+ idx = findfirst (eachindex (akeys)) do i
350+ ! avisited[i] && isequal_with_metadata (k, akeys[i])
351+ end
352+ idx === nothing && return false
353+ avisited[idx] = true
354+ isequal_with_metadata (b[k], a[akeys[idx]]) || return false
355+ end
356+ return true
357+ end
358+
323359"""
324360 $(TYPEDSIGNATURES)
325361
@@ -341,6 +377,7 @@ Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each
341377This is to ensure true equality of any symbolic elements, if present.
342378"""
343379function isequal_with_metadata (a:: Union{AbstractArray, Tuple} , b:: Union{AbstractArray, Tuple} )
380+ a === b && return true
344381 typeof (a) == typeof (b) || return false
345382 if a isa AbstractArray
346383 size (a) == size (b) || return false
0 commit comments