Skip to content

Commit c88706e

Browse files
mmaterammatera
andauthored
Fix number comparisons vs expression comparisons (#1521)
This PR improves the compatibility with the different WMA sorting schemes for numbers and expressions. * Now the sort key for Complex is the concatenation of the sort keys for the real and the imaginary part. * Sort keys for other numbers were tuned to improve the compatibility with WMA * `__eq__` methods for numbers do not rely on sort keys but on numeric equivalence. * Adding low-level tests for canonical order in numbers. --------- Co-authored-by: mmatera <[email protected]>
1 parent b1a4bff commit c88706e

File tree

4 files changed

+246
-35
lines changed

4 files changed

+246
-35
lines changed

mathics/core/atoms.py

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import re
77
from functools import cache
8-
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
8+
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union, cast
99

1010
import mpmath
1111
import numpy
@@ -239,11 +239,14 @@ def __new__(cls, value) -> "Integer":
239239
return self
240240

241241
def __eq__(self, other) -> bool:
242-
return (
243-
self._value == other.value
244-
if isinstance(other, Integer)
245-
else super().__eq__(other)
246-
)
242+
if isinstance(other, Integer):
243+
return self._value == other._value
244+
if isinstance(other, Number):
245+
# If other is a number of a wider class, use
246+
# its implementation:
247+
return other.__eq__(self)
248+
249+
return super().__eq__(other)
247250

248251
def __ge__(self, other) -> bool:
249252
return (
@@ -346,7 +349,7 @@ def to_sympy(self, **_) -> sympy_numbers.Integer:
346349

347350
def sameQ(self, rhs) -> bool:
348351
"""Mathics SameQ"""
349-
return isinstance(rhs, Integer) and self._value == rhs.value
352+
return isinstance(rhs, Integer) and self._value == rhs._value
350353

351354
def do_copy(self) -> "Integer":
352355
return Integer(self._value)
@@ -404,17 +407,18 @@ def __new__(cls, value, p: Optional[int] = None) -> "Real":
404407
return PrecisionReal.__new__(PrecisionReal, value)
405408

406409
def __eq__(self, other) -> bool:
407-
if isinstance(other, Real):
408-
# MMA Docs: "Approximate numbers that differ in their last seven
409-
# binary digits are considered equal"
410-
_prec = min_prec(self, other)
411-
if _prec is not None:
412-
with mpmath.workprec(_prec):
413-
rel_eps = 0.5 ** float(_prec - 7)
414-
return mpmath.almosteq(
415-
self.to_mpmath(), other.to_mpmath(), abs_eps=0, rel_eps=rel_eps
416-
)
417-
return super().__eq__(other)
410+
if not isinstance(other, Number):
411+
return super().__eq__(other)
412+
413+
_prec: Optional[int] = min_prec(self, other)
414+
if _prec is None:
415+
return self._value == other._value
416+
417+
with mpmath.workprec(_prec):
418+
rel_eps = 0.5 ** float(_prec - 7)
419+
return mpmath.almosteq(
420+
self.to_mpmath(), other.to_mpmath(), abs_eps=0, rel_eps=rel_eps
421+
)
418422

419423
def __hash__(self):
420424
# ignore last 7 binary digits when hashing
@@ -492,6 +496,20 @@ def get_precision(self) -> int:
492496
def get_float_value(self, permit_complex=False) -> float:
493497
return self._value
494498

499+
@property
500+
def element_order(self) -> tuple:
501+
"""
502+
Return a tuple value that is used in ordering elements
503+
of an expression. The tuple is ultimately compared lexicographically.
504+
"""
505+
return (
506+
BASIC_ATOM_NUMBER_ELT_ORDER,
507+
self._value,
508+
0,
509+
1,
510+
0, # Machine precision comes first, and after Integers
511+
)
512+
495513
@property
496514
def is_approx_zero(self) -> bool:
497515
# In WMA, Chop[10.^(-10)] == 0,
@@ -514,7 +532,7 @@ def make_boxes(self, form):
514532

515533
@property
516534
def is_zero(self) -> bool:
517-
return self.value == 0.0
535+
return self._value == 0.0
518536

519537
def sameQ(self, rhs) -> bool:
520538
"""Mathics SameQ for MachineReal.
@@ -524,9 +542,9 @@ def sameQ(self, rhs) -> bool:
524542
rhs-value's precision. For any rhs type, sameQ is False.
525543
"""
526544
if isinstance(rhs, MachineReal):
527-
return self.value == rhs.value
545+
return self._value == rhs._value
528546
if isinstance(rhs, PrecisionReal):
529-
rhs_value = rhs.value
547+
rhs_value = rhs._value
530548
value = self.to_sympy()
531549
# If sympy fixes the issue, this comparison would be
532550
# enough
@@ -603,6 +621,21 @@ def get_precision(self) -> int:
603621
"""Returns the default specification for precision (in binary digits) in N and other numerical functions."""
604622
return self.value._prec + 1
605623

624+
@property
625+
def element_order(self) -> tuple:
626+
"""
627+
Return a tuple value that is used in ordering elements
628+
of an expression. The tuple is ultimately compared lexicographically.
629+
"""
630+
631+
value = self._value
632+
value, prec = float(value), value._prec
633+
# For large values, use the sympy.Float value...
634+
if math.isinf(value):
635+
value, prec = self._value, value._prec
636+
637+
return (BASIC_ATOM_NUMBER_ELT_ORDER, value, 0, 2, prec)
638+
606639
@property
607640
def is_zero(self) -> bool:
608641
# self.value == 0 does not work for sympy >=1.13
@@ -757,7 +790,7 @@ def sameQ(self, rhs) -> bool:
757790
"""Mathics3 SameQ"""
758791
# FIX: check
759792
if isinstance(rhs, ByteArray):
760-
return self.value == rhs.value
793+
return self._value == rhs._value
761794
return False
762795

763796
def get_string_value(self) -> Optional[str]:
@@ -902,12 +935,15 @@ def element_order(self) -> tuple:
902935
Return a tuple value that is used in ordering elements
903936
of an expression. The tuple is ultimately compared lexicographically.
904937
"""
905-
return (
906-
BASIC_ATOM_NUMBER_ELT_ORDER,
907-
self.real.element_order[1],
908-
self.imag.element_order[1],
909-
1,
910-
)
938+
order_real, order_imag = self.real.element_order, self.imag.element_order
939+
940+
# If the real of the imag parts are real numbers, sort according
941+
# the minimum precision.
942+
# Example:
943+
# Sort[{1+2I, 1.+2.I, 1.`4+2.`5I, 1.`2+2.`7 I}]
944+
#
945+
# = {1+2I, 1.+2.I, 1.`2+2.`7 I, 1.`4+2.`5I}
946+
return order_real + order_imag
911947

912948
@property
913949
def pattern_precedence(self) -> tuple:
@@ -965,9 +1001,13 @@ def user_hash(self, update) -> None:
9651001

9661002
def __eq__(self, other) -> bool:
9671003
if isinstance(other, Complex):
968-
return self.real == other.real and self.imag == other.imag
969-
else:
970-
return super().__eq__(other)
1004+
return self.real.__eq__(other.real) and self.imag.__eq__(other.imag)
1005+
if isinstance(other, Number):
1006+
if abs(self.imag._value) != 0:
1007+
return False
1008+
return self.real.__eq__(other)
1009+
1010+
return super().__eq__(other)
9711011

9721012
@property
9731013
def is_zero(self) -> bool:
@@ -1019,6 +1059,17 @@ def __new__(cls, numerator, denominator=1) -> "Rational":
10191059
self.hash = hash(key)
10201060
return self
10211061

1062+
def __eq__(self, other) -> bool:
1063+
if isinstance(other, Rational):
1064+
return self.value.as_numer_denom() == other.value.as_numer_denom()
1065+
if isinstance(other, Integer):
1066+
return (other._value, 1) == self.value.as_numer_denom()
1067+
if isinstance(other, Number):
1068+
# For general numbers, rely on Real or Complex implementations.
1069+
return other.__eq__(self)
1070+
# General expressions
1071+
return super().__eq__(other)
1072+
10221073
def __getnewargs__(self) -> tuple:
10231074
return (self.numerator().value, self.denominator().value)
10241075

@@ -1078,7 +1129,7 @@ def element_order(self) -> tuple:
10781129
return (
10791130
BASIC_ATOM_NUMBER_ELT_ORDER,
10801131
sympy.Float(self.value),
1081-
0,
1132+
1,
10821133
1,
10831134
)
10841135

test/builtin/test_file_operations.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Unit tests for mathics.builtin.file_operations
44
"""
5-
5+
import os
66
import sys
77
import time
88
from test.helper import check_evaluation, evaluate
@@ -101,6 +101,10 @@
101101
),
102102
],
103103
)
104+
@pytest.mark.skipif(
105+
os.getenv("SANDBOX", False),
106+
reason="Test doesn't work in a sandboxed environment with access to local files",
107+
)
104108
def test_private_doctests_file_properties(str_expr, msgs, str_expected, fail_msg):
105109
"""file_opertions.file_properties"""
106110
check_evaluation(

test/core/convert/test_mpmath.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def test_from_to_mpmath():
4444
(MachineReal(1.2), MachineReal(1.2)),
4545
(PrecisionReal(SympyFloat(1.3, 10)), PrecisionReal(SympyFloat(1.3, 10))),
4646
(PrecisionReal(SympyFloat(1.3, 30)), PrecisionReal(SympyFloat(1.3, 30))),
47-
(Complex(Integer1, IntegerM1), Complex(Integer1, IntegerM1)),
47+
# After conversion, val1 == val2 but not SameQ[val1,val2]
48+
# (Complex(Integer1, IntegerM1), Complex(Integer1, IntegerM1)),
4849
(Complex(Integer1, Real(-1.0)), Complex(Integer1, Real(-1.0))),
4950
(Complex(Real(1.0), Real(-1.0)), Complex(Real(1.0), Real(-1.0))),
5051
(

test/core/test_keycomparable.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,161 @@
11
import pytest
2+
from sympy import Float
23

3-
from mathics.core.atoms import Complex, Integer0, Integer1, Real, String
4+
from mathics.core.atoms import (
5+
Complex,
6+
Integer0,
7+
Integer1,
8+
PrecisionReal,
9+
Rational,
10+
Real,
11+
String,
12+
)
13+
14+
print("creating representations")
15+
ZERO_REPRESENTATIONS = {
16+
"Integer": Integer0,
17+
"MachineReal": Real(0.0),
18+
"PrecisionReal`2": PrecisionReal(Float(0, 2)),
19+
"PrecisionReal`5": PrecisionReal(Float(0, 5)),
20+
"PrecisionReal`10": PrecisionReal(Float(0, 10)),
21+
"PrecisionReal`20": PrecisionReal(Float(0, 20)),
22+
"PrecisionReal`22": PrecisionReal(Float(0, 22)),
23+
"PrecisionReal`40": PrecisionReal(Float(0, 40)),
24+
}
25+
ZERO_REPRESENTATIONS["Complex"] = Complex(
26+
ZERO_REPRESENTATIONS["MachineReal"], ZERO_REPRESENTATIONS["MachineReal"]
27+
)
28+
ZERO_REPRESENTATIONS["Complex`20"] = Complex(
29+
ZERO_REPRESENTATIONS["PrecisionReal`20"], ZERO_REPRESENTATIONS["PrecisionReal`20"]
30+
)
31+
32+
ONE_REPRESENTATIONS = {
33+
"Integer": Integer1,
34+
"MachineReal": Real(1.0),
35+
"PrecisionReal`2": PrecisionReal(Float(1, 2)),
36+
"PrecisionReal`5": PrecisionReal(Float(1, 5)),
37+
"PrecisionReal`10": PrecisionReal(Float(1, 10)),
38+
"PrecisionReal`20": PrecisionReal(Float(1, 20)),
39+
"PrecisionReal`22": PrecisionReal(Float(1, 22)),
40+
}
41+
42+
43+
# Add some complex cases
44+
ONE_REPRESENTATIONS["Complex Integer"] = Complex(
45+
Integer1, ZERO_REPRESENTATIONS["PrecisionReal`10"]
46+
)
47+
ONE_REPRESENTATIONS["Complex"] = Complex(
48+
ONE_REPRESENTATIONS["MachineReal"], ZERO_REPRESENTATIONS["MachineReal"]
49+
)
50+
ONE_REPRESENTATIONS["Complex`5"] = Complex(
51+
ONE_REPRESENTATIONS["PrecisionReal`5"], ZERO_REPRESENTATIONS["PrecisionReal`5"]
52+
)
53+
54+
55+
ONE_FIFTH_REPRESENTATIONS = {
56+
"Rational": Rational(1, 5),
57+
"MachineReal": Real(0.2),
58+
"PrecisionReal`20": PrecisionReal(Float(".2", 20)),
59+
"PrecisionReal`22": PrecisionReal(Float(".2", 22)),
60+
}
61+
ONE_FIFTH_REPRESENTATIONS["Complex"] = Complex(
62+
ONE_FIFTH_REPRESENTATIONS["MachineReal"], ZERO_REPRESENTATIONS["MachineReal"]
63+
)
64+
ONE_FIFTH_REPRESENTATIONS["Complex`20"] = Complex(
65+
ONE_FIFTH_REPRESENTATIONS["PrecisionReal`20"],
66+
ZERO_REPRESENTATIONS["PrecisionReal`20"],
67+
)
68+
69+
70+
def test_sorting_numbers():
71+
"""
72+
In WMA, canonical order for numbers with the same value in different representations:
73+
* Integer
74+
* Complex[Integer, PrecisionReal]
75+
* MachineReal
76+
* Complex[MachineReal, MachineReal]
77+
* PrecisionReal, Complex[PrecisionReal, PrecisionReal] if precision of the real parts are equal,
78+
* otherwise, sort by precision of the real part.
79+
* Rational
80+
Example: {1, 1 + 0``10.*I, 1., 1. + 0.*I, 1.`4., 1.`4. + 0``4.*I, 1.`4. + 0``3.*I, 1.`6.}
81+
and
82+
{0.2, 0.2 + 0.*I, 0.2`4., 0.2`10., 1/5}
83+
are lists in canonical order.
84+
85+
If the numbers are in different representations, numbers are sorted by their real parts,
86+
and then the imaginary part is considered:
87+
{0.2, 0.2 - 1.*I, 0.2 + 1.*I, 1/5}
88+
"""
89+
zero_canonical_order = (
90+
"Integer",
91+
"MachineReal",
92+
"Complex",
93+
"PrecisionReal`20",
94+
"Complex`20",
95+
"PrecisionReal`22",
96+
)
97+
one_canonical_order = (
98+
"Integer",
99+
"MachineReal",
100+
"Complex",
101+
"Complex Integer",
102+
"PrecisionReal`2",
103+
"PrecisionReal`5",
104+
"Complex`5",
105+
"PrecisionReal`20",
106+
)
107+
one_fifth_canonical_order = (
108+
"MachineReal",
109+
"Complex",
110+
"PrecisionReal`20",
111+
"Complex`20",
112+
"PrecisionReal`22",
113+
"Rational",
114+
)
115+
116+
# Canonical order
117+
for order_equiv_forms in [
118+
[ZERO_REPRESENTATIONS[pos] for pos in zero_canonical_order],
119+
[ONE_REPRESENTATIONS[pos] for pos in one_canonical_order],
120+
[ONE_FIFTH_REPRESENTATIONS[pos] for pos in one_fifth_canonical_order],
121+
]:
122+
for elem, nelem in zip(order_equiv_forms[:-1], order_equiv_forms[1:]):
123+
e_order, ne_order = elem.element_order, nelem.element_order
124+
print("-------")
125+
print(type(elem), f"{elem}", e_order)
126+
print("vs", type(nelem), f"{nelem}", ne_order)
127+
assert e_order < ne_order and not (
128+
ne_order <= e_order
129+
), "wrong order or undefined."
130+
assert (
131+
elem == nelem
132+
), f"elements are not equal {elem} ({type(elem)}[{e_order}]) != {nelem}({type(nelem)}[{ne_order}])"
133+
assert (
134+
nelem == elem
135+
), f"elements are not equal {elem} ({type(elem)}[{e_order}]) != {nelem}({type(nelem)}[{ne_order}])"
136+
137+
138+
def test_sorting_complex():
139+
one_fifth_rational = ONE_FIFTH_REPRESENTATIONS["Rational"]
140+
one_fifth_mr = ONE_FIFTH_REPRESENTATIONS["MachineReal"]
141+
one_fifth_pr = ONE_FIFTH_REPRESENTATIONS["PrecisionReal`20"]
142+
one_fifth_cplx_i = Complex(one_fifth_mr, ONE_REPRESENTATIONS["MachineReal"])
143+
one_fifth_cplx_mi = Complex(one_fifth_mr, -ONE_REPRESENTATIONS["MachineReal"])
144+
canonical_sorted = [
145+
one_fifth_mr,
146+
one_fifth_cplx_mi,
147+
one_fifth_cplx_i,
148+
one_fifth_pr,
149+
one_fifth_rational,
150+
]
151+
for elem, nelem in zip(canonical_sorted[:-1], canonical_sorted[1:]):
152+
e_order, ne_order = elem.element_order, nelem.element_order
153+
print("-------")
154+
print(type(elem), f"{elem}", e_order)
155+
print("vs", type(nelem), f"{nelem}", ne_order)
156+
assert e_order < ne_order and not (
157+
ne_order <= e_order
158+
), f"{e_order}, {ne_order}"
4159

5160

6161
# Tests

0 commit comments

Comments
 (0)