From 4cddccd4bdbd3ce0ecd6f4a90e623c03f2b52074 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Aug 2025 18:19:42 -0400 Subject: [PATCH 1/2] fix: fix comparison involving `NaN` --- src/comp.jl | 9 ++++++--- test/comp.jl | 6 ++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/comp.jl b/src/comp.jl index 62ccfbb..16f8697 100644 --- a/src/comp.jl +++ b/src/comp.jl @@ -58,25 +58,28 @@ end (==)(x::MonomialVector, mv::AbstractVector) = x == monomial_vector(mv) # Comparison of Term -function (==)(p::Polynomial{V,M}, q::Polynomial{V,M}) where {V,M} +function _compare(p::Polynomial{V,M}, q::Polynomial{V,M}, comparator) where {V,M} # terms should be sorted and without zeros if MP.nterms(p) != MP.nterms(q) return false end for i in eachindex(p.a) - if p.x[i] != q.x[i] + if !comparator(p.x[i], q.x[i]) # There should not be zero terms @assert p.a[i] != 0 @assert q.a[i] != 0 return false end - if p.a[i] != q.a[i] + if !comparator(p.a[i], q.a[i]) return false end end return true end +(==)(p::Polynomial{V, M}, q::Polynomial{V, M}) where {V, M} = _compare(p, q, (==)) +Base.isequal(p::Polynomial{V, M}, q::Polynomial{V, M}) where {V, M} = _compare(p, q, isequal) + function Base.isapprox( p::Polynomial{V,M,S}, q::Polynomial{V,M,T}; diff --git a/test/comp.jl b/test/comp.jl index 94c526b..08da364 100644 --- a/test/comp.jl +++ b/test/comp.jl @@ -97,3 +97,9 @@ end _test_monomials([x, y], 1:2, [y, x, y^2, x * y, x^2]) _test_monomials([x, y], [0, 1, 3], [1, y, x, y^3, x*y^2, x^2*y, x^3]) end + +@testset "Comparison with NaN" begin + @polyvar p + @test (NaN + p) != (NaN + p) + @test isequal(NaN + p, NaN + p) +end From 2b9ca2bdebcc88c830a1f64013dce42ded252fdd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 9 Aug 2025 10:34:15 -0400 Subject: [PATCH 2/2] test: add allocation tests to NaN equality checking --- test/comp.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/comp.jl b/test/comp.jl index 08da364..6486cb3 100644 --- a/test/comp.jl +++ b/test/comp.jl @@ -100,6 +100,14 @@ end @testset "Comparison with NaN" begin @polyvar p - @test (NaN + p) != (NaN + p) - @test isequal(NaN + p, NaN + p) + poly1 = NaN + p + poly2 = NaN + p + @test poly1 != poly2 + @test (@allocated poly1 != poly2) == 0 + @test poly1 != poly1 + @test (@allocated poly1 != poly1) == 0 + @test isequal(poly1, poly2) + @test (@allocated isequal(poly1, poly2)) == 0 + @test isequal(poly1, poly1) + @test (@allocated isequal(poly1, poly1)) == 0 end