Skip to content

Commit 5c6972e

Browse files
authored
Fix nth and get type support (#570)
* `nth` previously would respond to nearly any collection type, which was incorrect. `nth` should only respond to sequential types, not maps and sets which _can produce_ sequences but are not themselves sequential. * `get` should technically respond to all types, though in most cases it returns `nil` (or its default value). It should respond to set types as if by calling them with the key. It should respond to all Python types as by `__getitem__`.
1 parent ecc545f commit 5c6972e

File tree

7 files changed

+196
-107
lines changed

7 files changed

+196
-107
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121

2222
### Changed
2323
* Basilisp set and map types are now backed by the HAMT provided by `immutables` (#557)
24+
* `get` now responds `nil` (or its default) for any unsupported types (#570)
25+
* `nth` now supports only sequential collections (or `nil`) and will throw an exception for any invalid types (#570)
2426

2527
### Fixed
2628
* Fixed a bug where the Basilisp AST nodes for return values of `deftype` members could be marked as _statements_ rather than _expressions_, resulting in an incorrect `nil` return (#523)

src/basilisp/core.lpy

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,13 @@
321321
~@body)))))
322322

323323
(defn nth
324-
"Returns the ith element of coll (0-indexed), if it exists.
325-
nil otherwise. If i is out of bounds, throws an IndexError unless
326-
notfound is specified."
324+
"Returns the `i`th element of `coll` (0-indexed), if it exists or `nil`
325+
otherwise. If `i` is out of bounds, throws an `IndexError` unless `notfound`
326+
is specified.
327+
328+
`coll` may be any sequential collection type (such as a list or vector), string,
329+
or `nil`. If `coll` is not one of the supported types, a `TypeError` will be
330+
thrown."
327331
([coll i]
328332
(basilisp.lang.runtime/nth coll i))
329333
([coll i notfound]
@@ -1912,7 +1916,11 @@
19121916
(apply (.-dissoc m) ks))
19131917

19141918
(defn get
1915-
"Return the entry of m corresponding to k if it exists or nil/default otherwise."
1919+
"Return the entry of `m` corresponding to `k` if it exists or `nil`/`default`
1920+
otherwise.
1921+
1922+
`m` may be any associative type (such as a vector or map), set type, or string.
1923+
If `m` is not one of the supported types, `get` always returns `nil`/`default`."
19161924
([m k]
19171925
(basilisp.lang.runtime/get m k))
19181926
([m k default]

src/basilisp/lang/compiler/analyzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,7 @@ def __deftype_and_reify_impls_are_all_abstract( # pylint: disable=too-many-bran
15801580
if not interface.var.is_bound:
15811581
logger.log(
15821582
TRACE,
1583-
f"{special_form} interface Var '{interface.form}' is not bound"
1583+
f"{special_form} interface Var '{interface.form}' is not bound "
15841584
"and cannot be checked for abstractness; deferring to runtime",
15851585
)
15861586
unverifiably_abstract.add(interface)

src/basilisp/lang/runtime.py

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
from basilisp.lang.reference import ReferenceBase
5555
from basilisp.lang.typing import CompilerOpts, LispNumber
5656
from basilisp.lang.util import OBJECT_DUNDER_METHODS, demunge, is_abstract, munge
57-
from basilisp.logconfig import TRACE
5857
from basilisp.util import Maybe
5958

6059
logger = logging.getLogger(__name__)
@@ -1032,23 +1031,12 @@ def nth(coll, i: int, notfound=__nth_sentinel):
10321031
"""Returns the ith element of coll (0-indexed), if it exists.
10331032
None otherwise. If i is out of bounds, throws an IndexError unless
10341033
notfound is specified."""
1035-
try:
1036-
for j, e in enumerate(coll):
1037-
if i == j:
1038-
return e
1039-
except TypeError:
1040-
pass
1041-
else:
1042-
if notfound is not __nth_sentinel:
1043-
return notfound
1044-
raise IndexError(f"Index {i} out of bounds")
1045-
10461034
raise TypeError(f"nth not supported on object of type {type(coll)}")
10471035

10481036

10491037
@nth.register(type(None))
10501038
def _nth_none(_: None, i: int, notfound=__nth_sentinel) -> None:
1051-
return None
1039+
return notfound if notfound is not __nth_sentinel else None
10521040

10531041

10541042
@nth.register(Sequence)
@@ -1059,18 +1047,67 @@ def _nth_sequence(coll: Sequence, i: int, notfound=__nth_sentinel):
10591047
if notfound is not __nth_sentinel:
10601048
return notfound
10611049
raise ex
1062-
except TypeError as ex:
1063-
# Log these at TRACE so they don't gum up the DEBUG logs since most
1064-
# cases where this exception occurs are not bugs.
1065-
logger.log(TRACE, "Ignored %s: %s", type(ex).__name__, ex)
1050+
1051+
1052+
@nth.register(ISeq)
1053+
def _nth_iseq(coll: ISeq, i: int, notfound=__nth_sentinel):
1054+
for j, e in enumerate(coll):
1055+
if i == j:
1056+
return e
1057+
1058+
if notfound is not __nth_sentinel:
1059+
return notfound
1060+
1061+
raise IndexError(f"Index {i} out of bounds")
1062+
1063+
1064+
@functools.singledispatch
1065+
def contains(coll, k):
1066+
"""Return true if o contains the key k."""
1067+
return k in coll
1068+
1069+
1070+
@contains.register(IAssociative)
1071+
def _contains_iassociative(coll, k):
1072+
return coll.contains(k)
1073+
1074+
1075+
@functools.singledispatch
1076+
def get(m, k, default=None): # pylint: disable=unused-argument
1077+
"""Return the value of k in m. Return default if k not found in m."""
1078+
return default
1079+
1080+
1081+
@get.register(dict)
1082+
@get.register(list)
1083+
@get.register(str)
1084+
def _get_others(m, k, default=None):
1085+
try:
1086+
return m[k]
1087+
except (KeyError, IndexError):
1088+
return default
1089+
1090+
1091+
@get.register(IPersistentSet)
1092+
@get.register(frozenset)
1093+
@get.register(set)
1094+
def _get_settypes(m, k, default=None):
1095+
if k in m:
1096+
return k
1097+
return default
1098+
1099+
1100+
@get.register(ILookup)
1101+
def _get_ilookup(m, k, default=None):
1102+
return m.val_at(k, default)
10661103

10671104

10681105
@functools.singledispatch
10691106
def assoc(m, *kvs):
10701107
"""Associate keys to values in associative data structure m. If m is None,
10711108
returns a new Map with key-values kvs."""
10721109
raise TypeError(
1073-
f"Object of type {type(m)} does not implement Associative interface"
1110+
f"Object of type {type(m)} does not implement IAssociative interface"
10741111
)
10751112

10761113

@@ -1090,7 +1127,7 @@ def update(m, k, f, *args):
10901127
calling f(old_v, *args). If m is None, use an empty map. If k is not in m, old_v will be
10911128
None."""
10921129
raise TypeError(
1093-
f"Object of type {type(m)} does not implement Associative interface"
1130+
f"Object of type {type(m)} does not implement IAssociative interface"
10941131
)
10951132

10961133

@@ -1112,7 +1149,8 @@ def conj(coll, *xs):
11121149
depending on the type of coll. conj returns the same type as coll. If coll
11131150
is None, return a list with xs conjoined."""
11141151
raise TypeError(
1115-
f"Object of type {type(coll)} does not implement Collection interface"
1152+
f"Object of type {type(coll)} does not implement "
1153+
"IPersistentCollection interface"
11161154
)
11171155

11181156

@@ -1228,32 +1266,6 @@ def __ge__(self, other):
12281266
return lseq.sequence(sorted(coll, key=key))
12291267

12301268

1231-
@functools.singledispatch
1232-
def contains(coll, k):
1233-
"""Return true if o contains the key k."""
1234-
return k in coll
1235-
1236-
1237-
@contains.register(IAssociative)
1238-
def _contains_iassociative(coll, k):
1239-
return coll.contains(k)
1240-
1241-
1242-
@functools.singledispatch
1243-
def get(m, k, default=None):
1244-
"""Return the value of k in m. Return default if k not found in m."""
1245-
try:
1246-
return m[k]
1247-
except (KeyError, IndexError, TypeError) as e:
1248-
logger.log(TRACE, "Ignored %s: %s", type(e).__name__, e)
1249-
return default
1250-
1251-
1252-
@get.register(ILookup)
1253-
def _get_ilookup(m, k, default=None):
1254-
return m.val_at(k, default)
1255-
1256-
12571269
def is_special_form(s: sym.Symbol) -> bool:
12581270
"""Return True if s names a special form."""
12591271
return s in _SPECIAL_FORMS

tests/basilisp/core_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,10 +1415,8 @@ class TestRandom:
14151415
COLLS = [
14161416
vec.v("a", "b", "c"),
14171417
llist.l("a", "b", "c"),
1418-
lset.s("a", "b", "c"),
14191418
["a", "b", "c"],
14201419
("a", "b", "c"),
1421-
{"a", "b", "c"},
14221420
]
14231421

14241422
@pytest.fixture(scope="class", params=COLLS)

tests/basilisp/runtime_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def test_apply():
156156

157157
def test_nth():
158158
assert None is runtime.nth(None, 1)
159+
assert "not found" == runtime.nth(None, 4, "not found")
159160
assert "l" == runtime.nth("hello world", 2)
160161
assert "l" == runtime.nth(["h", "e", "l", "l", "o"], 2)
161162
assert "l" == runtime.nth(llist.l("h", "e", "l", "l", "o"), 2)
@@ -164,20 +165,72 @@ def test_nth():
164165

165166
assert "Z" == runtime.nth(llist.l("h", "e", "l", "l", "o"), 7, "Z")
166167
assert "Z" == runtime.nth(lseq.sequence(["h", "e", "l", "l", "o"]), 7, "Z")
168+
assert "Z" == runtime.nth(vec.v("h", "e", "l", "l", "o"), 7, "Z")
167169

168170
with pytest.raises(IndexError):
169171
runtime.nth(llist.l("h", "e", "l", "l", "o"), 7)
170172

171173
with pytest.raises(IndexError):
172174
runtime.nth(lseq.sequence(["h", "e", "l", "l", "o"]), 7)
173175

176+
with pytest.raises(IndexError):
177+
runtime.nth(vec.v("h", "e", "l", "l", "o"), 7)
178+
179+
with pytest.raises(TypeError):
180+
runtime.nth(lmap.Map.empty(), 2)
181+
182+
with pytest.raises(TypeError):
183+
runtime.nth(lmap.map({"a": 1, "b": 2, "c": 3}), 2)
184+
185+
with pytest.raises(TypeError):
186+
runtime.nth(lset.Set.empty(), 2)
187+
188+
with pytest.raises(TypeError):
189+
runtime.nth(lset.s(1, 2, 3), 2)
190+
174191
with pytest.raises(TypeError):
175192
runtime.nth(3, 1)
176193

177194
with pytest.raises(TypeError):
178195
runtime.nth(3, 1, "Z")
179196

180197

198+
def test_get():
199+
assert None is runtime.get(None, "a")
200+
assert keyword.keyword("nada") is runtime.get(None, "a", keyword.keyword("nada"))
201+
assert None is runtime.get(3, "a")
202+
assert keyword.keyword("nada") is runtime.get(3, "a", keyword.keyword("nada"))
203+
assert 1 == runtime.get(lmap.map({"a": 1}), "a")
204+
assert None is runtime.get(lmap.map({"a": 1}), "b")
205+
assert 2 == runtime.get(lmap.map({"a": 1}), "b", 2)
206+
assert 1 == runtime.get(vec.v(1, 2, 3), 0)
207+
assert None is runtime.get(vec.v(1, 2, 3), 3)
208+
assert "nada" == runtime.get(vec.v(1, 2, 3), 3, "nada")
209+
assert "l" == runtime.get("hello world", 2)
210+
assert None is runtime.get("hello world", 50)
211+
assert "nada" == runtime.get("hello world", 50, "nada")
212+
assert "l" == runtime.get(["h", "e", "l", "l", "o"], 2)
213+
assert None is runtime.get(["h", "e", "l", "l", "o"], 50)
214+
assert "nada" == runtime.get(["h", "e", "l", "l", "o"], 50, "nada")
215+
assert 1 == runtime.get({"a": 1}, "a")
216+
assert None is runtime.get({"a": 1}, "b")
217+
assert 2 == runtime.get({"a": 1}, "b", 2)
218+
assert "a" == runtime.get({"a", "b", "c"}, "a")
219+
assert None is runtime.get({"a", "b", "c"}, "d")
220+
assert 2 == runtime.get({"a", "b", "c"}, "d", 2)
221+
assert "a" == runtime.get(frozenset({"a", "b", "c"}), "a")
222+
assert None is runtime.get(frozenset({"a", "b", "c"}), "d")
223+
assert 2 == runtime.get(frozenset({"a", "b", "c"}), "d", 2)
224+
assert "a" == runtime.get(lset.set({"a", "b", "c"}), "a")
225+
assert None is runtime.get(lset.set({"a", "b", "c"}), "d")
226+
assert 2 == runtime.get(lset.set({"a", "b", "c"}), "d", 2)
227+
228+
# lists are "supported" by virtue of the fact that `get`-ing them does not fail
229+
assert None is runtime.get(llist.l(1, 2, 3), 0)
230+
assert None is runtime.get(llist.l(1, 2, 3), 3)
231+
assert "nada" == runtime.get(llist.l(1, 2, 3), 0, "nada")
232+
233+
181234
def test_assoc():
182235
assert lmap.Map.empty() == runtime.assoc(None)
183236
assert lmap.map({"a": 1}) == runtime.assoc(None, "a", 1)

0 commit comments

Comments
 (0)