diff --git a/src/comparison.jl b/src/comparison.jl index cfa9cb9f..9814b4bb 100644 --- a/src/comparison.jl +++ b/src/comparison.jl @@ -21,29 +21,33 @@ Base.:(==)(x::RationalPoly, α::Nothing) = false Base.:(==)(α::Dict, x::RationalPoly) = false Base.:(==)(x::RationalPoly, α::Dict) = false -function right_term_eq(p::AbstractPolynomial, t) +function right_term_eq(p::AbstractPolynomial, t; comp = (==)) if iszero(p) iszero(t) else # terms/nterms ignore zero terms - nterms(p) == 1 && leading_term(p) == t + nterms(p) == 1 && comp(leading_term(p), t) end end -right_term_eq(p::_APL, t) = right_term_eq(polynomial(p), t) +right_term_eq(p::_APL, t; comp = (==)) = right_term_eq(polynomial(p), t; comp) -left_constant_eq(α, v::AbstractVariable) = false -right_constant_eq(v::AbstractVariable, α) = false -function _term_constant_eq(t::AbstractTermLike, α) +left_constant_eq(α, v::AbstractVariable; comp = (==)) = false +right_constant_eq(v::AbstractVariable, α; comp = (==)) = false +function _term_constant_eq(t::AbstractTermLike, α; comp = (==)) if iszero(t) iszero(α) else - α == coefficient(t) && isconstant(t) + comp(α, coefficient(t)) && isconstant(t) end end -left_constant_eq(α, t::AbstractTermLike) = _term_constant_eq(t, α) -right_constant_eq(t::AbstractTermLike, α) = _term_constant_eq(t, α) -left_constant_eq(α, p::_APL) = right_term_eq(p, α) -right_constant_eq(p::_APL, α) = right_term_eq(p, α) +function left_constant_eq(α, t::AbstractTermLike; comp = (==)) + return _term_constant_eq(t, α; comp) +end +function right_constant_eq(t::AbstractTermLike, α; comp = (==)) + return _term_constant_eq(t, α; comp) +end +left_constant_eq(α, p::_APL; comp = (==)) = right_term_eq(p, α; comp) +right_constant_eq(p::_APL, α; comp = (==)) = right_term_eq(p, α; comp) function Base.:(==)(mono::AbstractMonomial, v::AbstractVariable) return isone(degree(mono)) && variable(mono) == v @@ -58,17 +62,28 @@ function Base.:(==)(mono::AbstractMonomialLike, t::AbstractTerm) return isone(coefficient(t)) && mono == monomial(t) end -function Base.:(==)(t1::AbstractTerm, t2::AbstractTerm) +function _compare_term(t1::AbstractTerm, t2::AbstractTerm, comp) c1 = coefficient(t1) c2 = coefficient(t2) if iszero(c1) iszero(c2) else - c1 == c2 && monomial(t1) == monomial(t2) + comp(c1, c2) && comp(monomial(t1), monomial(t2)) end end + +Base.:(==)(t1::AbstractTerm, t2::AbstractTerm) = _compare_term(t1, t2, ==) Base.:(==)(p::AbstractPolynomial, t::AbstractTerm) = right_term_eq(p, t) Base.:(==)(t::AbstractTerm, p::AbstractPolynomial) = right_term_eq(p, t) +function Base.isequal(t1::AbstractTerm, t2::AbstractTerm) + return _compare_term(t1, t2, isequal) +end +function Base.isequal(p::AbstractPolynomial, t::AbstractTerm) + return right_term_eq(p, t; comp = isequal) +end +function Base.isequal(t::AbstractTerm, p::AbstractPolynomial) + return right_term_eq(p, t; comp = isequal) +end function compare_terms(p1::AbstractPolynomial, p2::AbstractPolynomial, isz, op) i1 = 1 @@ -110,6 +125,9 @@ end function Base.:(==)(p1::AbstractPolynomial, p2::AbstractPolynomial) return compare_terms(p1, p2, iszero, ==) end +function Base.isequal(p1::AbstractPolynomial, p2::AbstractPolynomial) + return compare_terms(p1, p2, iszero, isequal) +end Base.:(==)(p::RationalPoly, q::RationalPoly) = p.num * q.den == q.num * p.den # Solve ambiguity with (::PolyType, ::Any) @@ -117,6 +135,14 @@ Base.:(==)(p::_APL, q::RationalPoly) = p * q.den == q.num Base.:(==)(q::RationalPoly, p::_APL) = p == q Base.:(==)(α, q::RationalPoly) = α * q.den == q.num Base.:(==)(q::RationalPoly, α) = α == q +function Base.isequal(p::RationalPoly, q::RationalPoly) + return isequal(p.num * q.den, q.num * p.den) +end +# Solve ambiguity with (::PolyType, ::Any) +Base.isequal(p::_APL, q::RationalPoly) = isequal(p * q.den, q.num) +Base.isequal(q::RationalPoly, p::_APL) = isequal(p, q) +Base.isequal(α, q::RationalPoly) = isequal(α * q.den, q.num) +Base.isequal(q::RationalPoly, α) = isequal(α, q) # α could be a JuMP affine expression isapproxzero(α; ztol::Real = 0.0) = false diff --git a/test/commutative/comparison.jl b/test/commutative/comparison.jl index b7e12313..a13aa099 100644 --- a/test/commutative/comparison.jl +++ b/test/commutative/comparison.jl @@ -159,4 +159,29 @@ # Springer Science & Business Media, **2013**. end end + + @testset "Comparison with NaN" begin + Mod.@polyvar x y + @testset "$poly1 and $poly2" for (poly1, poly2) in [ + (NaN * x, NaN * x), + (NaN * x + 0, NaN * x + 0), + (NaN * x, NaN * x + 0), + (NaN * x / y, NaN * x / y), + (x / (NaN * y), x / (NaN * y)), + ((NaN * x) / (NaN * y), (NaN * x) / (NaN * y)), + ((NaN * x + 0) / (NaN * y + 0), (NaN * x + 0) / (NaN * y + 0)), + ] + @test poly1 != poly2 + @test poly1 != poly1 + @test isequal(poly1, poly2) + @test isequal(poly1, poly1) + # RationalPoly equality multiplies and thus allocates + if !(poly1 isa RationalPoly) + @test (@allocated poly1 != poly2) == 0 + @test (@allocated poly1 != poly1) == 0 + @test (@allocated isequal(poly1, poly2)) == 0 + @test (@allocated isequal(poly1, poly1)) == 0 + end + end + end end