Skip to content

Commit 4a5f0f6

Browse files
fix: fix comparison involving NaN
1 parent 9df2c20 commit 4a5f0f6

File tree

2 files changed

+37
-13
lines changed

2 files changed

+37
-13
lines changed

src/comparison.jl

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,29 @@ Base.:(==)(x::RationalPoly, α::Nothing) = false
2121
Base.:(==)(α::Dict, x::RationalPoly) = false
2222
Base.:(==)(x::RationalPoly, α::Dict) = false
2323

24-
function right_term_eq(p::AbstractPolynomial, t)
24+
function right_term_eq(p::AbstractPolynomial, t; comp = (==))
2525
if iszero(p)
2626
iszero(t)
2727
else
2828
# terms/nterms ignore zero terms
29-
nterms(p) == 1 && leading_term(p) == t
29+
nterms(p) == 1 && comp(leading_term(p), t)
3030
end
3131
end
32-
right_term_eq(p::_APL, t) = right_term_eq(polynomial(p), t)
32+
right_term_eq(p::_APL, t; comp = (==)) = right_term_eq(polynomial(p), t; comp)
3333

34-
left_constant_eq(α, v::AbstractVariable) = false
35-
right_constant_eq(v::AbstractVariable, α) = false
36-
function _term_constant_eq(t::AbstractTermLike, α)
34+
left_constant_eq(α, v::AbstractVariable; comp = (==)) = false
35+
right_constant_eq(v::AbstractVariable, α; comp = (==)) = false
36+
function _term_constant_eq(t::AbstractTermLike, α; comp = (==))
3737
if iszero(t)
3838
iszero(α)
3939
else
40-
α == coefficient(t) && isconstant(t)
40+
comp(α, coefficient(t)) && isconstant(t)
4141
end
4242
end
43-
left_constant_eq(α, t::AbstractTermLike) = _term_constant_eq(t, α)
44-
right_constant_eq(t::AbstractTermLike, α) = _term_constant_eq(t, α)
45-
left_constant_eq(α, p::_APL) = right_term_eq(p, α)
46-
right_constant_eq(p::_APL, α) = right_term_eq(p, α)
43+
left_constant_eq(α, t::AbstractTermLike; comp = (==)) = _term_constant_eq(t, α; comp)
44+
right_constant_eq(t::AbstractTermLike, α; comp = (==)) = _term_constant_eq(t, α; comp)
45+
left_constant_eq(α, p::_APL; comp = (==)) = right_term_eq(p, α; comp)
46+
right_constant_eq(p::_APL, α; comp = (==)) = right_term_eq(p, α; comp)
4747

4848
function Base.:(==)(mono::AbstractMonomial, v::AbstractVariable)
4949
return isone(degree(mono)) && variable(mono) == v
@@ -58,17 +58,22 @@ function Base.:(==)(mono::AbstractMonomialLike, t::AbstractTerm)
5858
return isone(coefficient(t)) && mono == monomial(t)
5959
end
6060

61-
function Base.:(==)(t1::AbstractTerm, t2::AbstractTerm)
61+
function _compare_term(t1::AbstractTerm, t2::AbstractTerm, comp)
6262
c1 = coefficient(t1)
6363
c2 = coefficient(t2)
6464
if iszero(c1)
6565
iszero(c2)
6666
else
67-
c1 == c2 && monomial(t1) == monomial(t2)
67+
comp(c1, c2) && comp(monomial(t1), monomial(t2))
6868
end
6969
end
70+
71+
Base.:(==)(t1::AbstractTerm, t2::AbstractTerm) = _compare_term(t1, t2, ==)
7072
Base.:(==)(p::AbstractPolynomial, t::AbstractTerm) = right_term_eq(p, t)
7173
Base.:(==)(t::AbstractTerm, p::AbstractPolynomial) = right_term_eq(p, t)
74+
Base.isequal(t1::AbstractTerm, t2::AbstractTerm) = _compare_term(t1, t2, isequal)
75+
Base.isequal(p::AbstractPolynomial, t::AbstractTerm) = right_term_eq(p, t; comp = isequal)
76+
Base.isequal(t::AbstractTerm, p::AbstractPolynomial) = right_term_eq(p, t; comp = isequal)
7277

7378
function compare_terms(p1::AbstractPolynomial, p2::AbstractPolynomial, isz, op)
7479
i1 = 1
@@ -110,13 +115,22 @@ end
110115
function Base.:(==)(p1::AbstractPolynomial, p2::AbstractPolynomial)
111116
return compare_terms(p1, p2, iszero, ==)
112117
end
118+
function Base.isequal(p1::AbstractPolynomial, p2::AbstractPolynomial)
119+
return compare_terms(p1, p2, iszero, isequal)
120+
end
113121

114122
Base.:(==)(p::RationalPoly, q::RationalPoly) = p.num * q.den == q.num * p.den
115123
# Solve ambiguity with (::PolyType, ::Any)
116124
Base.:(==)(p::_APL, q::RationalPoly) = p * q.den == q.num
117125
Base.:(==)(q::RationalPoly, p::_APL) = p == q
118126
Base.:(==)(α, q::RationalPoly) = α * q.den == q.num
119127
Base.:(==)(q::RationalPoly, α) = α == q
128+
Base.isequal(p::RationalPoly, q::RationalPoly) = isequal(p.num * q.den, q.num * p.den)
129+
# Solve ambiguity with (::PolyType, ::Any)
130+
Base.isequal(p::_APL, q::RationalPoly) = isequal(p * q.den, q.num)
131+
Base.isequal(q::RationalPoly, p::_APL) = isequal(p, q)
132+
Base.isequal(α, q::RationalPoly) = isequal* q.den, q.num)
133+
Base.isequal(q::RationalPoly, α) = isequal(α, q)
120134

121135
# α could be a JuMP affine expression
122136
isapproxzero(α; ztol::Real = 0.0) = false

test/commutative/comparison.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,14 @@
159159
# Springer Science & Business Media, **2013**.
160160
end
161161
end
162+
163+
@testset "Comparison with NaN" begin
164+
Mod.@polyvar x
165+
@test isequal(NaN * x, NaN * x)
166+
@test isequal(NaN * x + 0, NaN * x + 0)
167+
@test isequal(NaN * x, NaN * x + 0)
168+
@test !=(NaN * x, NaN * x)
169+
@test !=(NaN * x + 0, NaN * x + 0)
170+
@test !=(NaN * x, NaN * x + 0)
171+
end
162172
end

0 commit comments

Comments
 (0)