@@ -544,65 +544,102 @@ def _fn_args_body(ctx: CompilerContext, arg_vec: vec.Vector, # pylint:disable=t
544
544
FunctionArityDetails = Tuple [int , bool , llist .List ]
545
545
546
546
547
- def _assert_no_recur (form : lseq .Seq ) -> None :
547
+ def _is_sym_macro (ctx : CompilerContext , form : sym .Symbol ) -> bool :
548
+ """Determine if the symbol in the current context points to a macro.
549
+
550
+ This function is used in asserting that recur only appears in a tail position.
551
+ Since macros expand at compile time, we can skip asserting in the un-expanded
552
+ macro call, since macros are checked after macroexpansion."""
553
+ if form .ns is not None :
554
+ if form .ns == ctx .current_ns .name :
555
+ v = ctx .current_ns .find (sym .symbol (form .name ))
556
+ if v is not None :
557
+ return _is_macro (v )
558
+ ns_sym = sym .symbol (form .ns )
559
+ if ns_sym in ctx .current_ns .aliases :
560
+ aliased_ns = ctx .current_ns .aliases [ns_sym ]
561
+ v = Var .find (sym .symbol (form .name , ns = aliased_ns ))
562
+ if v is not None :
563
+ return _is_macro (v )
564
+
565
+ v = ctx .current_ns .find (form )
566
+ if v is not None :
567
+ return _is_macro (v )
568
+
569
+ return False
570
+
571
+
572
+ def _assert_no_recur (ctx : CompilerContext , form : lseq .Seq ) -> None :
548
573
"""Assert that the iterable contains no recur special form."""
549
574
for child in form :
550
575
if isinstance (child , lseq .Seqable ):
551
- _assert_no_recur (child .seq ())
576
+ _assert_no_recur (ctx , child .seq ())
552
577
elif isinstance (child , (llist .List , lseq .Seq )):
553
- if child .first == _RECUR :
554
- raise CompilerException ("Recur appears outside tail position" )
555
- _assert_no_recur (child )
578
+ if isinstance (child .first , sym .Symbol ):
579
+ if _is_sym_macro (ctx , child .first ):
580
+ continue
581
+ elif child .first == _RECUR :
582
+ raise CompilerException (f"Recur appears outside tail position in { form } " )
583
+ elif child .first == _FN :
584
+ continue
585
+ _assert_no_recur (ctx , child )
556
586
557
587
558
- def _assert_recur_is_tail (form : lseq .Seq ) -> None : # noqa: C901
588
+ def _assert_recur_is_tail (ctx : CompilerContext , form : lseq .Seq ) -> None : # noqa: C901
559
589
"""Assert that recur special forms only appear in tail position in a function."""
560
590
listlen = 0
561
591
first_recur_index = None
562
592
for i , child in enumerate (form ): # pylint:disable=too-many-nested-blocks
563
593
listlen += 1
564
594
if isinstance (child , (llist .List , lseq .Seq )):
565
- if child .first == _RECUR :
595
+ if _is_sym_macro (ctx , child .first ):
596
+ continue
597
+ elif child .first == _RECUR :
566
598
if first_recur_index is None :
567
599
first_recur_index = i
568
600
elif child .first == _DO :
569
- _assert_recur_is_tail (child )
601
+ _assert_recur_is_tail (ctx , child )
602
+ elif child .first == _FN :
603
+ continue
570
604
elif child .first == _IF :
571
- _assert_recur_is_tail (runtime .nth (child , 2 ))
605
+ _assert_no_recur (ctx , lseq .sequence ([runtime .nth (child , 1 )]))
606
+ _assert_recur_is_tail (ctx , lseq .sequence ([runtime .nth (child , 2 )]))
572
607
try :
573
- _assert_recur_is_tail (runtime .nth (child , 3 ))
608
+ _assert_recur_is_tail (ctx , lseq . sequence ([ runtime .nth (child , 3 )] ))
574
609
except IndexError :
575
610
pass
576
611
elif child .first == _LET :
577
- _assert_no_recur (runtime .nth (child , 1 ).seq ())
612
+ for binding , val in seq (runtime .nth (child , 1 )).grouped (2 ):
613
+ _assert_no_recur (ctx , lseq .sequence ([binding ]))
614
+ _assert_no_recur (ctx , lseq .sequence ([val ]))
578
615
let_body = runtime .nthnext (child , 2 )
579
616
if let_body :
580
- _assert_recur_is_tail (let_body )
617
+ _assert_recur_is_tail (ctx , let_body )
581
618
elif child .first == _TRY :
582
619
if isinstance (runtime .nth (child , 1 ), llist .List ):
583
- _assert_recur_is_tail (llist . l ( runtime .nth (child , 1 )))
620
+ _assert_recur_is_tail (ctx , lseq . sequence ([ runtime .nth (child , 1 )] ))
584
621
catch_finally = runtime .nthnext (child , 2 )
585
622
if catch_finally :
586
623
for clause in catch_finally :
587
624
if isinstance (clause , llist .List ):
588
625
if clause .first == _CATCH :
589
- _assert_recur_is_tail (llist . l ( runtime .nthnext (clause , 2 )))
626
+ _assert_recur_is_tail (ctx , lseq . sequence ([ runtime .nthnext (clause , 2 )] ))
590
627
elif clause .first == _FINALLY :
591
- _assert_no_recur (llist . l ( clause .rest ) )
628
+ _assert_no_recur (ctx , clause .rest )
592
629
elif child .first in {_DEF , _IMPORT , _INTEROP_CALL , _INTEROP_PROP , _THROW , _VAR }:
593
- _assert_no_recur (child )
630
+ _assert_no_recur (ctx , child )
594
631
else :
595
- _assert_recur_is_tail (child )
632
+ _assert_recur_is_tail (ctx , child )
596
633
else :
597
634
if isinstance (child , lseq .Seqable ):
598
- _assert_no_recur (child .seq ())
635
+ _assert_no_recur (ctx , child .seq ())
599
636
600
637
if first_recur_index is not None :
601
638
if first_recur_index != listlen - 1 :
602
639
raise CompilerException ("Recur appears outside tail position" )
603
640
604
641
605
- def _fn_arities (form : llist .List ) -> Iterable [FunctionArityDetails ]:
642
+ def _fn_arities (ctx : CompilerContext , form : llist .List ) -> Iterable [FunctionArityDetails ]:
606
643
"""Return the arities of a function definition and some additional details about
607
644
the argument vector. Verify that all arities are compatible. In particular, this
608
645
function will throw a CompilerException if any of the following are true:
@@ -624,15 +661,15 @@ def _fn_arities(form: llist.List) -> Iterable[FunctionArityDetails]:
624
661
(fn a [] :a) ;=> '(([] :a))"""
625
662
if not all (map (lambda f : isinstance (f , llist .List ) and isinstance (f .first , vec .Vector ), form )):
626
663
assert isinstance (form .first , vec .Vector )
627
- _assert_recur_is_tail (form )
664
+ _assert_recur_is_tail (ctx , form )
628
665
yield len (form .first ), False , form
629
666
return
630
667
631
668
arg_counts : Dict [int , llist .List ] = {}
632
669
has_vargs = False
633
670
vargs_len = None
634
671
for arity in form :
635
- _assert_recur_is_tail (arity )
672
+ _assert_recur_is_tail (ctx , arity )
636
673
637
674
# Verify each arity is unique
638
675
arg_count = len (arity .first )
@@ -790,7 +827,7 @@ def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
790
827
791
828
with ctx .new_recur_point (name ):
792
829
rest_idx = 1 + int (has_name )
793
- arities = list (_fn_arities (form [rest_idx :]))
830
+ arities = list (_fn_arities (ctx , form [rest_idx :]))
794
831
if len (arities ) == 0 :
795
832
raise CompilerException ("Function def must have argument vector" )
796
833
elif len (arities ) == 1 :
@@ -1279,7 +1316,7 @@ def _list_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
1279
1316
# non-tail recur forms
1280
1317
try :
1281
1318
if ctx .recur_point .name :
1282
- _assert_recur_is_tail (lseq .sequence ([expanded ]))
1319
+ _assert_recur_is_tail (ctx , lseq .sequence ([expanded ]))
1283
1320
except IndexError :
1284
1321
pass
1285
1322
0 commit comments