Skip to content

Commit 403bff7

Browse files
authored
Support sort-by maps and boolean comparison fns. (#709)
Hi, could you please review this update, which adds support for sorting maps with `sort-by` and for passing boolean comparator fns to it. It partly addresses #708. Previously, maps delegated sorting to the immutables library, which returns only the keys of the map when sorted; this patch first converts a map into a sequence of key/value pairs, which is what Clojure returns as the result. Boolean comparators were not working and are now converted to 3-way comparators using the new `_fn_to_comparator` fn. Since sorting is now done on key/value vector pairs, it came to my attention that `vector.__lt__`, which the sorting delegates to, was not working as expected, so it has been updated to follow Clojure's vector comparison semantics. I've added tests for all of the above. I haven't touched `sort` much yet, as I wanted to discuss this patch first. Thanks --------- Co-authored-by: ikappaki <[email protected]>
1 parent d425923 commit 403bff7

File tree

5 files changed

+132
-23
lines changed

5 files changed

+132
-23
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
* Fix issue with `case` evaluating all of its clauses expressions (#699).
99
* Fix issue with relative paths dropping their first character on MS-Windows (#703).
1010
* Fix incompatibility with `(str nil)` returning "nil" (#706).
11+
* Fix `sort-by` support for maps and boolean comparator fns (#709).
1112

1213
## [v0.1.0a2]
1314
### Added

src/basilisp/lang/runtime.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import itertools
77
import logging
88
import math
9+
import numbers
910
import re
1011
import sys
1112
import threading
@@ -1375,39 +1376,64 @@ def _compare_sets(x: IPersistentSet, y) -> int:
13751376
def sort(coll, f=None) -> Optional[ISeq]:
    """Return a sorted sequence of the elements in coll.

    If a comparator function f is provided, compare elements in coll using
    f. The comparator may be either a 3-way comparison fn (returning a
    negative, zero or positive number) or a boolean "less than"-style
    predicate such as `>`; both are normalised through `_fn_to_comparator`,
    the same way `sort_by` treats its comparator. When f is None the
    elements are ordered by their natural Python ordering.

    Maps are sorted as a sequence of their key/value entries."""
    if isinstance(coll, IPersistentMap):
        coll = lseq.to_seq(coll)

    # A raw boolean predicate must not be handed to cmp_to_key directly:
    # neither True nor False is < 0, so every pair would compare as
    # "not less" and the input order would be returned unchanged.
    key = None if f is None else functools.cmp_to_key(_fn_to_comparator(f))
    return lseq.sequence(sorted(coll, key=key))
13791382

13801383

1381-
def _fn_to_comparator(f):
    """Coerce the comparator fn `f` to a 3-way comparator fn.

    `compare` is returned unchanged. Any other fn is wrapped so that:

    - a numeric (non-boolean) result of `f(x, y)` is returned as is;
    - a truthy result means `x` sorts before `y` and yields -1;
    - otherwise `f(y, x)` is consulted: a truthy result yields 1, and a
      falsey one yields 0 (neither element sorts before the other)."""
    if f == compare:  # pylint: disable=comparison-with-callable
        return f

    def cmp(x, y):
        r = f(x, y)
        # bool is a subclass of int, so it has to be excluded explicitly
        # for boolean predicates to reach the branches below.
        if isinstance(r, numbers.Number) and not isinstance(r, bool):
            return r
        elif r:
            return -1
        elif f(y, x):
            return 1
        else:
            return 0

    return cmp


def sort_by(keyfn, coll, cmp=compare) -> Optional[ISeq]:
    """Return a sorted sequence of the elements in coll, where the sort
    order is determined by comparing `(keyfn element)`.

    If a comparator function cmp is provided, the keys are compared using
    cmp, otherwise the `compare` fn is used. Passing None for cmp is
    treated the same as omitting it.

    The comparator fn can be either a boolean or a 3-way comparison fn.

    Maps are sorted as a sequence of their key/value entries."""
    if isinstance(coll, IPersistentMap):
        coll = lseq.to_seq(coll)

    comparator = _fn_to_comparator(compare if cmp is None else cmp)

    # cmp_to_key supplies the rich comparison wrapper; applying keyfn while
    # building the sort key means it runs once per element rather than on
    # every comparison.
    to_key = functools.cmp_to_key(comparator)
    return lseq.sequence(sorted(coll, key=lambda obj: to_key(keyfn(obj))))
14131439

src/basilisp/lang/vector.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,27 @@ def __len__(self):
119119
return len(self._inner)
120120

121121
def __lt__(self, other):
    """Return true if the `self` vector is shorter than the `other`
    vector or, for vectors of equal length, if the first unequal
    element in `self` when iterating from left to right is less than
    the corresponding `other` element.

    Returns False when `other` is None and `NotImplemented` when it is
    not a `PersistentVector`.

    This is to support the comparing and sorting operations of
    vectors in Clojure."""
    if other is None:  # pragma: no cover
        return False
    if not isinstance(other, PersistentVector):
        return NotImplemented
    if len(self) != len(other):
        return len(self) < len(other)

    # Only the first differing position decides the result; testing
    # whether *any* position is smaller would wrongly report
    # [2 55] < [1 99].
    for x, y in zip(self, other):
        if x < y:
            return True
        elif y < x:
            return False
    return False
129143

130144
def _lrepr(self, **kwargs) -> str:
131145
return _seq_lrepr(self._inner, "[", "]", meta=self._meta, **kwargs)

tests/basilisp/test_core_fns.lpy

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@
200200

201201
0 [0 1 2] [0 1 2]
202202
-1 [0 1 2] [1 1 2]
203+
1 [2 55] [1 99]
203204
1 [1 1 2] [0 1 2]))
204205

205206
(testing "un-comparables"
@@ -979,7 +980,42 @@
979980
(let [cmp (fn [v1 v2] (- v2 v1))]
980981
(is (= '() (sort-by count cmp [])))
981982
(is (= '([1 2 3] [5 5] [:a])
982-
(sort-by count cmp [[1 2 3] [:a] [5 5]]))))))
983+
(sort-by count cmp [[1 2 3] [:a] [5 5]])))))
984+
985+
(testing "sorting vectors"
986+
(are [res v] (= res (sort-by identity v))
987+
'([1] [3] [1 2]) [[1] [1 2] [3]]
988+
'([1] [1 2] [1 3]) [[1 3] [1] [1 2]]
989+
'([1] [0 1] [0 1]) [[0 1] [1] [0 1] ])
990+
991+
;; taken from clojuredocs
992+
(is (= '([1 2] [2 2] [2 3]) (sort-by first [[1 2] [2 2] [2 3]])))
993+
(is (= '([2 2] [2 3] [1 2]) (sort-by first > [[1 2] [2 2] [2 3]]) )))
994+
995+
(testing "sorting maps"
996+
(are [res v] (= res (sort-by identity v))
997+
'([:3 18] [:5 28] [:9 23]) {:9 23 :3 18 :5 28})
998+
999+
;; taken from clojuredocs
1000+
(is (= '({:rank 1} {:rank 2} {:rank 3}) (sort-by :rank [{:rank 2} {:rank 3} {:rank 1}])))
1001+
(let [x [{:foo 2 :bar 11}
1002+
{:bar 99 :foo 1}
1003+
{:bar 55 :foo 2}
1004+
{:foo 1 :bar 77}]
1005+
order [55 77 99 11]]
1006+
(is (= '({:foo 1, :bar 77} {:bar 99, :foo 1} {:foo 2, :bar 11} {:bar 55, :foo 2})
1007+
(sort-by (juxt :foo :bar) x)))
1008+
(is (= '({:bar 55, :foo 2} {:foo 1, :bar 77} {:bar 99, :foo 1} {:foo 2, :bar 11})
1009+
(sort-by
1010+
#((into {} (map-indexed (fn [i e] [e i]) order)) (:bar %))
1011+
x))))
1012+
(is (= '([:foo 7] [:baz 5] [:bar 3]) (sort-by val > {:foo 7, :bar 3, :baz 5})))
1013+
(is (= '({:value 1, :label "a"} {:value 2, :label "b"} {:value 3, :label "c"})
1014+
(sort-by :value [{:value 1 :label "a"} {:value 3 :label "c"} {:value 2 :label "b"}])))
1015+
(is (= '({:value 3 :label "c"} {:value 2, :label "b"} {:value 1, :label "a"})
1016+
(sort-by :value #(> %1 %2) [{:value 1 :label "a"} {:value 3 :label "c"} {:value 2 :label "b"}])))
1017+
(is (= '({:label "c"} {:value 1, :label "a"} {:value 2, :label "b"})
1018+
(sort-by :value [{:value 1 :label "a"} {:label "c"} {:value 2 :label "b"}])))))
9831019

9841020
(deftest zipmap-test
9851021
(are [x y z] (= x (zipmap y z))

tests/basilisp/vector_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,38 @@ def test_vector_with_meta():
162162
assert vec3.meta == lmap.m(tag=keyword("macro"))
163163

164164

165+
@pytest.mark.parametrize(
    "result,v1,v2",
    [
        (False, vec.v(), vec.v()),
        (False, vec.v(1), vec.v(1)),
        (False, vec.v(1, 2), vec.v(1, 2)),
        (True, vec.v(1, 2), vec.v(1, 3)),
        (False, vec.v(1, 3), vec.v(1, 2)),
        (True, vec.v(1, 2), vec.v(1, 2, 3)),
        (False, vec.v(1, 2, 3), vec.v(1, 2)),
        # Regression cases: the first differing element decides the
        # ordering even when a later element compares the other way. An
        # "any element is smaller" check gets the first of these wrong.
        (False, vec.v(2, 55), vec.v(1, 99)),
        (True, vec.v(1, 99), vec.v(2, 55)),
    ],
)
def test_vector_less_than(result, v1, v2):
    """`v1 < v2` orders by length first, then by the first unequal element."""
    assert result == (v1 < v2)
179+
180+
181+
@pytest.mark.parametrize(
    "result,v1,v2",
    [
        (False, vec.v(), vec.v()),
        (False, vec.v(1), vec.v(1)),
        (False, vec.v(1, 2), vec.v(1, 2)),
        (False, vec.v(1, 2), vec.v(1, 3)),
        (True, vec.v(1, 3), vec.v(1, 2)),
        (False, vec.v(1, 2), vec.v(1, 2, 3)),
        (True, vec.v(1, 2, 3), vec.v(1, 2)),
        # Regression cases: the first differing element decides the
        # ordering even when a later element compares the other way.
        (True, vec.v(2, 55), vec.v(1, 99)),
        (False, vec.v(1, 99), vec.v(2, 55)),
    ],
)
def test_vector_greater_than(result, v1, v2):
    """`v1 > v2` orders by length first, then by the first unequal element."""
    assert result == (v1 > v2)
195+
196+
165197
@pytest.mark.parametrize(
166198
"o",
167199
[

0 commit comments

Comments
 (0)