Skip to content

Commit 9cd2fcf

Browse files
authored
Implement 3-way comparator (#609)
* Implement 3-way comparator * c o v e r a g e * Some new stuff * Fix the tests * Test it * Do that thing * Don't forget the negation * Dead
1 parent fbd2865 commit 9cd2fcf

File tree

6 files changed

+271
-1
lines changed

6 files changed

+271
-1
lines changed

src/basilisp/core.lpy

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,8 +1061,37 @@
10611061
[frac]
10621062
(.-denominator frac))
10631063

1064+
(defn compare
1065+
"Return either -1, 0, or 1 to indicate the relationship between x and y.
1066+
1067+
This is a 3-way comparator commonly used in Java-derived systems. Python does not
1068+
typically use 3-way comparators, so this function convert's Python's `__lt__` and
1069+
`__gt__` method returns into one of the 3-way comparator return values.
1070+
1071+
Comparisons are generally only valid between homogeneous objects, with the exception
1072+
of numbers which can be compared regardless of type. Most scalar value types
1073+
(including nil, booleans, numbers, strings, keywords, and symbols) are comparable.
1074+
nil compares less than all values except itself. NaN compares equal to all numbers,
1075+
including itself. Strings are compared lexicographically. Symbols and keywords are
1076+
sorted first on their namespace, if they have one, and then on their name. Symbols
1077+
and keywords with namespaces always sort ahead of those without. Symbols cannot be
1078+
compared to keywords.
1079+
1080+
Of the built in collection types, only vectors can be compared. Vectors are compared
1081+
first by their length and then element-wise.
1082+
1083+
Other collections such as maps, sequences, and sets cannot be compared.
1084+
1085+
Python objects supporting `__lt__` and `__gt__` can generally be compared."
1086+
[x y]
1087+
(basilisp.lang.runtime/compare x y))
1088+
10641089
(defn sort
1065-
"Return a sorted sequence of the elements from coll."
1090+
"Return a sorted sequence of the elements from coll.
1091+
1092+
Unlike in Clojure, this function does not use `compare` directly. Instead, the
1093+
heuristics that `compare` uses to produce three-way comparator return values
1094+
are used to guide sorting."
10661095
([coll]
10671096
(basilisp.lang.runtime/sort coll))
10681097
([cmp coll]

src/basilisp/lang/keyword.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import threading
2+
from functools import total_ordering
23
from typing import Iterable, Optional
34

45
from basilisp.lang import map as lmap
@@ -8,6 +9,7 @@
89
_INTERN: IPersistentMap[int, "Keyword"] = lmap.PersistentMap.empty()
910

1011

12+
@total_ordering
1113
class Keyword(ILispObject):
1214
__slots__ = ("_name", "_ns", "_hash")
1315

@@ -38,6 +40,19 @@ def __eq__(self, other):
3840
def __hash__(self):
3941
return self._hash
4042

43+
def __lt__(self, other):
44+
if other is None: # pragma: no cover
45+
return False
46+
if not isinstance(other, Keyword):
47+
return NotImplemented
48+
if self._ns is None and other._ns is None:
49+
return self._name < other._name
50+
if self._ns is None:
51+
return True
52+
if other._ns is None:
53+
return False
54+
return self._ns < other._ns or self._name < other._name
55+
4156
def __call__(self, m: IAssociative, default=None):
4257
try:
4358
return m.val_at(self, default)

src/basilisp/lang/runtime.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import decimal
23
import functools
34
import importlib
45
import inspect
@@ -1249,6 +1250,55 @@ def quotient(num, div) -> LispNumber:
12491250
return math.trunc(num / div)
12501251

12511252

1253+
@functools.singledispatch
1254+
def compare(x, y) -> int:
1255+
"""Return either -1, 0, or 1 to indicate the relationship between x and y.
1256+
1257+
This is a 3-way comparator commonly used in Java-derived systems. Python does not
1258+
typically use 3-way comparators, so this function convert's Python's `__lt__` and
1259+
`__gt__` method returns into one of the 3-way comparator return values."""
1260+
if y is None:
1261+
assert x is not None, "x cannot be nil"
1262+
return 1
1263+
return (x > y) - (x < y)
1264+
1265+
1266+
@compare.register(type(None))
1267+
def _compare_nil(_: None, y) -> int:
1268+
# nil is less than all values, except itself.
1269+
return 0 if y is None else -1
1270+
1271+
1272+
@compare.register(decimal.Decimal)
1273+
def _compare_decimal(x: decimal.Decimal, y) -> int:
1274+
# Decimal instances will not compare with float("nan"), so we need a special case
1275+
if isinstance(y, float):
1276+
return -compare(y, x) # pylint: disable=arguments-out-of-order
1277+
return (x > y) - (x < y)
1278+
1279+
1280+
@compare.register(float)
1281+
def _compare_float(x, y) -> int:
1282+
if y is None:
1283+
return 1
1284+
if math.isnan(x):
1285+
return 0
1286+
return (x > y) - (x < y)
1287+
1288+
1289+
@compare.register(IPersistentSet)
1290+
def _compare_sets(x: IPersistentSet, y) -> int:
1291+
# Sets are not comparable (because there is no total ordering between sets).
1292+
# However, in Python comparison is done using __lt__ and __gt__, which AbstractSet
1293+
# inconveniently also uses as part of it's API for comparing sets with subset and
1294+
# superset relationships. To "break" that, we just override the comparison method.
1295+
# One consequence of this is that it may be possible to sort a collection of sets,
1296+
# since `compare` isn't actually consulted in sorting.
1297+
raise TypeError(
1298+
f"cannot compare instances of '{type(x).__name__}' and '{type(y).__name__}'"
1299+
)
1300+
1301+
12521302
def sort(coll, f=None) -> Optional[ISeq]:
12531303
"""Return a sorted sequence of the elements in coll. If a comparator
12541304
function f is provided, compare elements in coll using f."""

src/basilisp/lang/symbol.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from functools import total_ordering
12
from typing import Optional
23

34
from basilisp.lang.interfaces import ILispObject, IPersistentMap, IWithMeta
45
from basilisp.lang.obj import lrepr
56
from basilisp.lang.util import munge
67

78

9+
@total_ordering
810
class Symbol(ILispObject, IWithMeta):
911
__slots__ = ("_name", "_ns", "_meta", "_hash")
1012

@@ -56,6 +58,19 @@ def __eq__(self, other):
5658
def __hash__(self):
5759
return self._hash
5860

61+
def __lt__(self, other):
62+
if other is None: # pragma: no cover
63+
return False
64+
if not isinstance(other, Symbol):
65+
return NotImplemented
66+
if self._ns is None and other._ns is None:
67+
return self._name < other._name
68+
if self._ns is None:
69+
return True
70+
if other._ns is None:
71+
return False
72+
return self._ns < other._ns or self._name < other._name
73+
5974

6075
def symbol(name: str, ns: Optional[str] = None, meta=None) -> Symbol:
6176
"""Create a new symbol."""

src/basilisp/lang/vector.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import total_ordering
12
from typing import Iterable, Optional, Sequence, TypeVar, Union
23

34
from pyrsistent import PVector, pvector # noqa # pylint: disable=unused-import
@@ -74,6 +75,7 @@ def to_persistent(self) -> "PersistentVector[T]":
7475
return PersistentVector(self._inner.persistent())
7576

7677

78+
@total_ordering
7779
class PersistentVector(
7880
IPersistentVector[T], IEvolveableCollection[TransientVector], ILispObject, IWithMeta
7981
):
@@ -116,6 +118,15 @@ def __iter__(self):
116118
def __len__(self):
117119
return len(self._inner)
118120

121+
def __lt__(self, other):
122+
if other is None: # pragma: no cover
123+
return False
124+
if not isinstance(other, PersistentVector):
125+
return NotImplemented
126+
if len(self) != len(other):
127+
return len(self) < len(other)
128+
return any(x < y for x, y in zip(self, other))
129+
119130
def _lrepr(self, **kwargs) -> str:
120131
return _seq_lrepr(self._inner, "[", "]", meta=self._meta, **kwargs)
121132

tests/basilisp/test_core_fns.lpy

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,156 @@
7575
(is (= v1 v2))
7676
(is (= {:tag vector} (meta v2))))))))
7777

78+
(deftest compare-test
79+
(testing "nil"
80+
(are [res x y] (= res (compare x y))
81+
0 nil nil
82+
-1 nil "a"
83+
1 "a" nil))
84+
85+
(testing "boolean"
86+
(are [res x y] (= res (compare x y))
87+
0 true true
88+
0 false false
89+
-1 false true
90+
1 true false))
91+
92+
(testing "numbers"
93+
(are [res x y] (= res (compare x y))
94+
1 ##NaN nil
95+
-1 nil ##NaN
96+
97+
0 ##NaN ##NaN
98+
0 ##NaN 1
99+
0 ##NaN 3.14
100+
0 ##NaN 22/7
101+
0 ##NaN 3.07M
102+
0 3.07M ##NaN
103+
0 1 ##NaN
104+
0 3.14 ##NaN
105+
0 3.14 ##NaN
106+
0 22/7 ##NaN
107+
108+
0 1 1
109+
0 1 1.0
110+
0 1 1M
111+
0 2 10/5
112+
0 1.0 1
113+
0 1.0 1M
114+
0 1.0 1.0
115+
0 2.0 10/5
116+
0 1M 1M
117+
0 1M 1
118+
0 1M 1.0
119+
0 2M 10/5
120+
0 10/5 2
121+
0 10/5 2.0
122+
0 10/5 2M
123+
0 10/5 10/5
124+
125+
1 3 1
126+
1 3 1.07
127+
1 3 2/5
128+
1 3 1.07M
129+
1 3.33 1.07
130+
1 3.33 1
131+
1 3.33 2/5
132+
1 3.33 1.07M
133+
1 3.33M 1.07M
134+
1 3.33M 1
135+
1 3.33M 1.07
136+
1 3.33M 10/5
137+
1 10/5 1
138+
1 10/5 1.0
139+
1 10/5 1.07M
140+
1 10/5 2/5
141+
142+
-1 1 3
143+
-1 1.07 3
144+
-1 2/5 3
145+
-1 1.07M 3
146+
-1 1.07 3.33
147+
-1 1 3.33
148+
-1 2/5 3.33
149+
-1 1.07M 3.33
150+
-1 1.07M 3.33M
151+
-1 1 3.33M
152+
-1 1.07 3.33M
153+
-1 10/5 3.33M
154+
-1 1 10/5
155+
-1 1.07 10/5
156+
-1 1.07M 10/5
157+
-1 2/5 10/5))
158+
159+
(testing "strings"
160+
(are [res x y] (= res (compare x y))
161+
0 "a" "a"
162+
-1 "a" "b"
163+
1 "b" "a"))
164+
165+
(testing "keywords"
166+
(are [res x y] (= res (compare x y))
167+
0 :a :a
168+
-1 :a :b
169+
1 :b :a
170+
171+
1 :a/b :b
172+
-1 :b :a/b
173+
174+
-1 :a/b :a/c
175+
-1 :a/b :b/b
176+
0 :a/b :a/b
177+
1 :a/c :a/b
178+
1 :b/b :a/b))
179+
180+
(testing "symbols"
181+
(are [res x y] (= res (compare x y))
182+
0 'a 'a
183+
-1 'a 'b
184+
1 'b 'a
185+
186+
1 'a/b 'b
187+
-1 'b 'a/b
188+
189+
-1 'a/b 'a/c
190+
-1 'a/b 'b/b
191+
0 'a/b 'a/b
192+
1 'a/c 'a/b
193+
1 'b/b 'a/b))
194+
195+
(testing "vectors"
196+
(are [res x y] (= res (compare x y))
197+
0 [] []
198+
-1 [] [1]
199+
1 [1] []
200+
201+
0 [0 1 2] [0 1 2]
202+
-1 [0 1 2] [1 1 2]
203+
1 [1 1 2] [0 1 2]))
204+
205+
(testing "un-comparables"
206+
(are [x y] (thrown? python/TypeError (compare x y))
207+
:a 'a
208+
:a "a"
209+
'a "a"
210+
'a :a
211+
"a" 'a
212+
"a" :a
213+
214+
[] '()
215+
216+
'() '()
217+
'(1 2 3) '(1 2 3)
218+
'(1 2 3) '(4 5 6)
219+
220+
#{} #{}
221+
#{1 2 3} #{1 2 3}
222+
#{1 2 3} #{4 5 6}
223+
224+
{} {}
225+
{:a 1 :b 2} {:a 1 :b 2}
226+
{:a 1 :b 2} {:c 3 :d 4})))
227+
78228
;;;;;;;;;;;;;;;;;;;;;;;;;;
79229
;; Collection Functions ;;
80230
;;;;;;;;;;;;;;;;;;;;;;;;;;

0 commit comments

Comments
 (0)