Skip to content

Commit 5f52efe

Browse files
authored
Fix recur-in-tail-position checks for let bindings and macros (#111)
1 parent c242215 commit 5f52efe

File tree

2 files changed

+85
-23
lines changed

2 files changed

+85
-23
lines changed

basilisp/compiler.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -544,65 +544,102 @@ def _fn_args_body(ctx: CompilerContext, arg_vec: vec.Vector, # pylint:disable=t
544544
FunctionArityDetails = Tuple[int, bool, llist.List]
545545

546546

547-
def _assert_no_recur(form: lseq.Seq) -> None:
547+
def _is_sym_macro(ctx: CompilerContext, form: sym.Symbol) -> bool:
548+
"""Determine if the symbol in the current context points to a macro.
549+
550+
This function is used in asserting that recur only appears in a tail position.
551+
Since macros expand at compile time, we can skip asserting in the un-expanded
552+
macro call, since macros are checked after macroexpansion."""
553+
if form.ns is not None:
554+
if form.ns == ctx.current_ns.name:
555+
v = ctx.current_ns.find(sym.symbol(form.name))
556+
if v is not None:
557+
return _is_macro(v)
558+
ns_sym = sym.symbol(form.ns)
559+
if ns_sym in ctx.current_ns.aliases:
560+
aliased_ns = ctx.current_ns.aliases[ns_sym]
561+
v = Var.find(sym.symbol(form.name, ns=aliased_ns))
562+
if v is not None:
563+
return _is_macro(v)
564+
565+
v = ctx.current_ns.find(form)
566+
if v is not None:
567+
return _is_macro(v)
568+
569+
return False
570+
571+
572+
def _assert_no_recur(ctx: CompilerContext, form: lseq.Seq) -> None:
548573
"""Assert that the iterable contains no recur special form."""
549574
for child in form:
550575
if isinstance(child, lseq.Seqable):
551-
_assert_no_recur(child.seq())
576+
_assert_no_recur(ctx, child.seq())
552577
elif isinstance(child, (llist.List, lseq.Seq)):
553-
if child.first == _RECUR:
554-
raise CompilerException("Recur appears outside tail position")
555-
_assert_no_recur(child)
578+
if isinstance(child.first, sym.Symbol):
579+
if _is_sym_macro(ctx, child.first):
580+
continue
581+
elif child.first == _RECUR:
582+
raise CompilerException(f"Recur appears outside tail position in {form}")
583+
elif child.first == _FN:
584+
continue
585+
_assert_no_recur(ctx, child)
556586

557587

558-
def _assert_recur_is_tail(form: lseq.Seq) -> None: # noqa: C901
588+
def _assert_recur_is_tail(ctx: CompilerContext, form: lseq.Seq) -> None: # noqa: C901
559589
"""Assert that recur special forms only appear in tail position in a function."""
560590
listlen = 0
561591
first_recur_index = None
562592
for i, child in enumerate(form): # pylint:disable=too-many-nested-blocks
563593
listlen += 1
564594
if isinstance(child, (llist.List, lseq.Seq)):
565-
if child.first == _RECUR:
595+
if _is_sym_macro(ctx, child.first):
596+
continue
597+
elif child.first == _RECUR:
566598
if first_recur_index is None:
567599
first_recur_index = i
568600
elif child.first == _DO:
569-
_assert_recur_is_tail(child)
601+
_assert_recur_is_tail(ctx, child)
602+
elif child.first == _FN:
603+
continue
570604
elif child.first == _IF:
571-
_assert_recur_is_tail(runtime.nth(child, 2))
605+
_assert_no_recur(ctx, lseq.sequence([runtime.nth(child, 1)]))
606+
_assert_recur_is_tail(ctx, lseq.sequence([runtime.nth(child, 2)]))
572607
try:
573-
_assert_recur_is_tail(runtime.nth(child, 3))
608+
_assert_recur_is_tail(ctx, lseq.sequence([runtime.nth(child, 3)]))
574609
except IndexError:
575610
pass
576611
elif child.first == _LET:
577-
_assert_no_recur(runtime.nth(child, 1).seq())
612+
for binding, val in seq(runtime.nth(child, 1)).grouped(2):
613+
_assert_no_recur(ctx, lseq.sequence([binding]))
614+
_assert_no_recur(ctx, lseq.sequence([val]))
578615
let_body = runtime.nthnext(child, 2)
579616
if let_body:
580-
_assert_recur_is_tail(let_body)
617+
_assert_recur_is_tail(ctx, let_body)
581618
elif child.first == _TRY:
582619
if isinstance(runtime.nth(child, 1), llist.List):
583-
_assert_recur_is_tail(llist.l(runtime.nth(child, 1)))
620+
_assert_recur_is_tail(ctx, lseq.sequence([runtime.nth(child, 1)]))
584621
catch_finally = runtime.nthnext(child, 2)
585622
if catch_finally:
586623
for clause in catch_finally:
587624
if isinstance(clause, llist.List):
588625
if clause.first == _CATCH:
589-
_assert_recur_is_tail(llist.l(runtime.nthnext(clause, 2)))
626+
_assert_recur_is_tail(ctx, lseq.sequence([runtime.nthnext(clause, 2)]))
590627
elif clause.first == _FINALLY:
591-
_assert_no_recur(llist.l(clause.rest))
628+
_assert_no_recur(ctx, clause.rest)
592629
elif child.first in {_DEF, _IMPORT, _INTEROP_CALL, _INTEROP_PROP, _THROW, _VAR}:
593-
_assert_no_recur(child)
630+
_assert_no_recur(ctx, child)
594631
else:
595-
_assert_recur_is_tail(child)
632+
_assert_recur_is_tail(ctx, child)
596633
else:
597634
if isinstance(child, lseq.Seqable):
598-
_assert_no_recur(child.seq())
635+
_assert_no_recur(ctx, child.seq())
599636

600637
if first_recur_index is not None:
601638
if first_recur_index != listlen - 1:
602639
raise CompilerException("Recur appears outside tail position")
603640

604641

605-
def _fn_arities(form: llist.List) -> Iterable[FunctionArityDetails]:
642+
def _fn_arities(ctx: CompilerContext, form: llist.List) -> Iterable[FunctionArityDetails]:
606643
"""Return the arities of a function definition and some additional details about
607644
the argument vector. Verify that all arities are compatible. In particular, this
608645
function will throw a CompilerException if any of the following are true:
@@ -624,15 +661,15 @@ def _fn_arities(form: llist.List) -> Iterable[FunctionArityDetails]:
624661
(fn a [] :a) ;=> '(([] :a))"""
625662
if not all(map(lambda f: isinstance(f, llist.List) and isinstance(f.first, vec.Vector), form)):
626663
assert isinstance(form.first, vec.Vector)
627-
_assert_recur_is_tail(form)
664+
_assert_recur_is_tail(ctx, form)
628665
yield len(form.first), False, form
629666
return
630667

631668
arg_counts: Dict[int, llist.List] = {}
632669
has_vargs = False
633670
vargs_len = None
634671
for arity in form:
635-
_assert_recur_is_tail(arity)
672+
_assert_recur_is_tail(ctx, arity)
636673

637674
# Verify each arity is unique
638675
arg_count = len(arity.first)
@@ -790,7 +827,7 @@ def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
790827

791828
with ctx.new_recur_point(name):
792829
rest_idx = 1 + int(has_name)
793-
arities = list(_fn_arities(form[rest_idx:]))
830+
arities = list(_fn_arities(ctx, form[rest_idx:]))
794831
if len(arities) == 0:
795832
raise CompilerException("Function def must have argument vector")
796833
elif len(arities) == 1:
@@ -1279,7 +1316,7 @@ def _list_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
12791316
# non-tail recur forms
12801317
try:
12811318
if ctx.recur_point.name:
1282-
_assert_recur_is_tail(lseq.sequence([expanded]))
1319+
_assert_recur_is_tail(ctx, lseq.sequence([expanded]))
12831320
except IndexError:
12841321
pass
12851322

tests/compiler_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,22 @@ def test_recur(ns_var: Var):
397397
assert 2 == lcompile("(last '(1 2))")
398398
assert 3 == lcompile("(last '(1 2 3))")
399399

400+
code = """
401+
(def rev-str
402+
(fn rev-str [s & args]
403+
(let [coerce (fn [in out]
404+
(if (seq (rest in))
405+
(recur (rest in) (cons (builtins/str (first in)) out))
406+
(cons (builtins/str (first in)) out)))]
407+
(.join \"\" (coerce (cons s args) '())))))
408+
"""
409+
410+
lcompile(code)
411+
412+
assert "a" == lcompile("(rev-str \"a\")")
413+
assert "ba" == lcompile("(rev-str \"a\" :b)")
414+
assert "3ba" == lcompile("(rev-str \"a\" :b 3)")
415+
400416

401417
def test_disallow_recur_in_special_forms(ns_var: Var):
402418
with pytest.raises(compiler.CompilerException):
@@ -440,6 +456,15 @@ def test_disallow_recur_outside_tail(ns_var: Var):
440456
with pytest.raises(compiler.CompilerException):
441457
lcompile("(fn [a] (let [a (recur \"a\")] a))")
442458

459+
with pytest.raises(compiler.CompilerException):
460+
lcompile("(fn [a] (let [a (do (recur \"a\"))] a))")
461+
462+
with pytest.raises(compiler.CompilerException):
463+
lcompile("(fn [a] (let [a (do :b (recur \"a\"))] a))")
464+
465+
with pytest.raises(compiler.CompilerException):
466+
lcompile("(fn [a] (let [a (do (recur \"a\") :c)] a))")
467+
443468
with pytest.raises(compiler.CompilerException):
444469
lcompile("(fn [a] (let [a \"a\"] (recur a) a))")
445470

0 commit comments

Comments
 (0)