Skip to content

Commit 8a641bc

Browse files
fix: fix comparison involving NaN (#334)
* fix: fix comparison involving `NaN` * test: add allocation tests for NaN equality checking * test: add tests for `NaN` equality checks on `RationalPoly`
1 parent 9df2c20 commit 8a641bc

File tree

2 files changed

+64
-13
lines changed

2 files changed

+64
-13
lines changed

src/comparison.jl

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,33 @@ Base.:(==)(x::RationalPoly, α::Nothing) = false
2121
Base.:(==)(α::Dict, x::RationalPoly) = false
2222
Base.:(==)(x::RationalPoly, α::Dict) = false
2323

24-
function right_term_eq(p::AbstractPolynomial, t)
24+
function right_term_eq(p::AbstractPolynomial, t; comp = (==))
2525
if iszero(p)
2626
iszero(t)
2727
else
2828
# terms/nterms ignore zero terms
29-
nterms(p) == 1 && leading_term(p) == t
29+
nterms(p) == 1 && comp(leading_term(p), t)
3030
end
3131
end
32-
right_term_eq(p::_APL, t) = right_term_eq(polynomial(p), t)
32+
right_term_eq(p::_APL, t; comp = (==)) = right_term_eq(polynomial(p), t; comp)
3333

34-
left_constant_eq(α, v::AbstractVariable) = false
35-
right_constant_eq(v::AbstractVariable, α) = false
36-
function _term_constant_eq(t::AbstractTermLike, α)
34+
left_constant_eq(α, v::AbstractVariable; comp = (==)) = false
35+
right_constant_eq(v::AbstractVariable, α; comp = (==)) = false
36+
function _term_constant_eq(t::AbstractTermLike, α; comp = (==))
3737
if iszero(t)
3838
iszero(α)
3939
else
40-
α == coefficient(t) && isconstant(t)
40+
comp(α, coefficient(t)) && isconstant(t)
4141
end
4242
end
43-
left_constant_eq(α, t::AbstractTermLike) = _term_constant_eq(t, α)
44-
right_constant_eq(t::AbstractTermLike, α) = _term_constant_eq(t, α)
45-
left_constant_eq(α, p::_APL) = right_term_eq(p, α)
46-
right_constant_eq(p::_APL, α) = right_term_eq(p, α)
43+
function left_constant_eq(α, t::AbstractTermLike; comp = (==))
44+
return _term_constant_eq(t, α; comp)
45+
end
46+
function right_constant_eq(t::AbstractTermLike, α; comp = (==))
47+
return _term_constant_eq(t, α; comp)
48+
end
49+
left_constant_eq(α, p::_APL; comp = (==)) = right_term_eq(p, α; comp)
50+
right_constant_eq(p::_APL, α; comp = (==)) = right_term_eq(p, α; comp)
4751

4852
function Base.:(==)(mono::AbstractMonomial, v::AbstractVariable)
4953
return isone(degree(mono)) && variable(mono) == v
@@ -58,17 +62,28 @@ function Base.:(==)(mono::AbstractMonomialLike, t::AbstractTerm)
5862
return isone(coefficient(t)) && mono == monomial(t)
5963
end
6064

61-
function Base.:(==)(t1::AbstractTerm, t2::AbstractTerm)
65+
function _compare_term(t1::AbstractTerm, t2::AbstractTerm, comp)
6266
c1 = coefficient(t1)
6367
c2 = coefficient(t2)
6468
if iszero(c1)
6569
iszero(c2)
6670
else
67-
c1 == c2 && monomial(t1) == monomial(t2)
71+
comp(c1, c2) && comp(monomial(t1), monomial(t2))
6872
end
6973
end
74+
75+
Base.:(==)(t1::AbstractTerm, t2::AbstractTerm) = _compare_term(t1, t2, ==)
7076
Base.:(==)(p::AbstractPolynomial, t::AbstractTerm) = right_term_eq(p, t)
7177
Base.:(==)(t::AbstractTerm, p::AbstractPolynomial) = right_term_eq(p, t)
78+
function Base.isequal(t1::AbstractTerm, t2::AbstractTerm)
79+
return _compare_term(t1, t2, isequal)
80+
end
81+
function Base.isequal(p::AbstractPolynomial, t::AbstractTerm)
82+
return right_term_eq(p, t; comp = isequal)
83+
end
84+
function Base.isequal(t::AbstractTerm, p::AbstractPolynomial)
85+
return right_term_eq(p, t; comp = isequal)
86+
end
7287

7388
function compare_terms(p1::AbstractPolynomial, p2::AbstractPolynomial, isz, op)
7489
i1 = 1
@@ -110,13 +125,24 @@ end
110125
function Base.:(==)(p1::AbstractPolynomial, p2::AbstractPolynomial)
111126
return compare_terms(p1, p2, iszero, ==)
112127
end
128+
function Base.isequal(p1::AbstractPolynomial, p2::AbstractPolynomial)
129+
return compare_terms(p1, p2, iszero, isequal)
130+
end
113131

114132
Base.:(==)(p::RationalPoly, q::RationalPoly) = p.num * q.den == q.num * p.den
115133
# Solve ambiguity with (::PolyType, ::Any)
116134
Base.:(==)(p::_APL, q::RationalPoly) = p * q.den == q.num
117135
Base.:(==)(q::RationalPoly, p::_APL) = p == q
118136
Base.:(==)(α, q::RationalPoly) = α * q.den == q.num
119137
Base.:(==)(q::RationalPoly, α) = α == q
138+
function Base.isequal(p::RationalPoly, q::RationalPoly)
139+
return isequal(p.num * q.den, q.num * p.den)
140+
end
141+
# Solve ambiguity with (::PolyType, ::Any)
142+
Base.isequal(p::_APL, q::RationalPoly) = isequal(p * q.den, q.num)
143+
Base.isequal(q::RationalPoly, p::_APL) = isequal(p, q)
144+
Base.isequal(α, q::RationalPoly) = isequal* q.den, q.num)
145+
Base.isequal(q::RationalPoly, α) = isequal(α, q)
120146

121147
# α could be a JuMP affine expression
122148
isapproxzero(α; ztol::Real = 0.0) = false

test/commutative/comparison.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,29 @@
159159
# Springer Science & Business Media, **2013**.
160160
end
161161
end
162+
163+
@testset "Comparison with NaN" begin
164+
Mod.@polyvar x y
165+
@testset "$poly1 and $poly2" for (poly1, poly2) in [
166+
(NaN * x, NaN * x),
167+
(NaN * x + 0, NaN * x + 0),
168+
(NaN * x, NaN * x + 0),
169+
(NaN * x / y, NaN * x / y),
170+
(x / (NaN * y), x / (NaN * y)),
171+
((NaN * x) / (NaN * y), (NaN * x) / (NaN * y)),
172+
((NaN * x + 0) / (NaN * y + 0), (NaN * x + 0) / (NaN * y + 0)),
173+
]
174+
@test poly1 != poly2
175+
@test poly1 != poly1
176+
@test isequal(poly1, poly2)
177+
@test isequal(poly1, poly1)
178+
# RationalPoly equality multiplies and thus allocates
179+
if !(poly1 isa RationalPoly)
180+
@test (@allocated poly1 != poly2) == 0
181+
@test (@allocated poly1 != poly1) == 0
182+
@test (@allocated isequal(poly1, poly2)) == 0
183+
@test (@allocated isequal(poly1, poly1)) == 0
184+
end
185+
end
186+
end
162187
end

0 commit comments

Comments
 (0)