Skip to content

Commit c7a6027

Browse files
authored
Recur correctly handles variadic arg lists (#116)
* Recur is only fixed to the current arity arglist and correctly handles variadic args * Add recur test * Test trampoline arguments * Fix lint error * More trampoline test cases
1 parent 5f52efe commit c7a6027

File tree

5 files changed

+166
-55
lines changed

5 files changed

+166
-55
lines changed

basilisp/compiler.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,11 @@ def new_frame(self, name):
130130

131131

132132
class RecurPoint:
133-
__slots__ = ('name', 'has_recur')
133+
__slots__ = ('name', 'args', 'has_recur')
134134

135-
def __init__(self, name: str) -> None:
135+
def __init__(self, name: str, args: vec.Vector) -> None:
136136
self.name = name
137+
self.args = args
137138
self.has_recur = False
138139

139140

@@ -159,8 +160,8 @@ def recur_point(self):
159160
return self._recur_points[-1]
160161

161162
@contextlib.contextmanager
162-
def new_recur_point(self, name: str):
163-
self._recur_points.append(RecurPoint(name))
163+
def new_recur_point(self, name: str, args: vec.Vector):
164+
self._recur_points.append(RecurPoint(name, args))
164165
yield
165166
self._recur_points.pop()
166167

@@ -715,7 +716,7 @@ def _compose_ifs(if_stmts: List[Dict[str, ast.AST]], orelse: List[ast.AST] = Non
715716

716717
def _single_arity_fn_ast(ctx: CompilerContext, name: str, fndef: llist.List) -> ASTStream:
717718
"""Generate Python AST nodes for a single-arity function."""
718-
with ctx.new_symbol_table(name):
719+
with ctx.new_symbol_table(name), ctx.new_recur_point(name, fndef.first):
719720
args, body, vargs = _fn_args_body(ctx, fndef.first, fndef.rest)
720721

721722
yield _dependency(_expressionize(body, name, args=args, vargs=vargs))
@@ -771,13 +772,26 @@ def __f_68(*multi_arity_args):
771772
has_rest = False
772773

773774
for arg_count, is_rest, arity in arities:
774-
has_rest = any([has_rest, is_rest])
775-
arity_name = f"{name}__arity{'_rest' if is_rest else arg_count}"
776-
777-
with ctx.new_symbol_table(arity_name):
778-
args, body, vargs = _fn_args_body(ctx, arity.first, arity.rest)
779-
780-
yield _dependency(_expressionize(body, arity_name, args=args, vargs=vargs))
775+
with ctx.new_recur_point(name, arity.first):
776+
has_rest = any([has_rest, is_rest])
777+
arity_name = f"{name}__arity{'_rest' if is_rest else arg_count}"
778+
779+
with ctx.new_symbol_table(arity_name):
780+
# Generate the arity function
781+
args, body, vargs = _fn_args_body(ctx, arity.first, arity.rest)
782+
yield _dependency(_expressionize(body, arity_name, args=args, vargs=vargs))
783+
784+
# If a recur point was established, we generate a trampoline version of the
785+
# generated function to allow repeated recursive calls without blowing up the
786+
# stack size.
787+
if ctx.recur_point.has_recur:
788+
yield _dependency(ast.Assign(targets=[ast.Name(id=arity_name, ctx=ast.Store())],
789+
value=ast.Call(func=_TRAMPOLINE_FN_NAME,
790+
args=[
791+
ast.Name(id=arity_name, ctx=ast.Load())],
792+
keywords=[])))
793+
794+
# Generate an if-statement branch for the arity-dispatch function
781795
compare_op = ast.GtE() if is_rest else ast.Eq()
782796
if_stmts.append(
783797
{"test": ast.Compare(left=ast.Call(func=_load_attr('len'),
@@ -808,15 +822,7 @@ def __f_68(*multi_arity_args):
808822
decorator_list=[],
809823
returns=None))
810824

811-
# If a recur point was established, we generate a trampoline version of the
812-
# generated function to allow repeated recursive calls without blowing up the
813-
# stack size.
814-
if ctx.recur_point.has_recur:
815-
yield _node(ast.Call(func=_TRAMPOLINE_FN_NAME,
816-
args=[ast.Name(id=ctx.recur_point.name, ctx=ast.Load())],
817-
keywords=[]))
818-
else:
819-
yield _node(ast.Name(id=name, ctx=ast.Load()))
825+
yield _node(ast.Name(id=name, ctx=ast.Load()))
820826

821827

822828
def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
@@ -825,18 +831,17 @@ def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
825831
has_name = isinstance(form[1], sym.Symbol)
826832
name = genname("__" + (munge(form[1].name) if has_name else _FN_PREFIX))
827833

828-
with ctx.new_recur_point(name):
829-
rest_idx = 1 + int(has_name)
830-
arities = list(_fn_arities(ctx, form[rest_idx:]))
831-
if len(arities) == 0:
832-
raise CompilerException("Function def must have argument vector")
833-
elif len(arities) == 1:
834-
_, _, fndef = arities[0]
835-
yield from _single_arity_fn_ast(ctx, name, fndef)
836-
return
837-
else:
838-
yield from _multi_arity_fn_ast(ctx, name, arities)
839-
return
834+
rest_idx = 1 + int(has_name)
835+
arities = list(_fn_arities(ctx, form[rest_idx:]))
836+
if len(arities) == 0:
837+
raise CompilerException("Function def must have argument vector")
838+
elif len(arities) == 1:
839+
_, _, fndef = arities[0]
840+
yield from _single_arity_fn_ast(ctx, name, fndef)
841+
return
842+
else:
843+
yield from _multi_arity_fn_ast(ctx, name, arities)
844+
return
840845

841846

842847
def _if_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
@@ -1080,8 +1085,10 @@ def _recur_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
10801085
expr_deps, exprs = _collection_literal_ast(ctx, form.rest)
10811086
yield from expr_deps
10821087

1088+
has_vargs = any([s == _AMPERSAND for s in ctx.recur_point.args])
10831089
yield _node(ast.Call(func=_TRAMPOLINE_ARGS_FN_NAME,
1084-
args=list(_unwrap_nodes(exprs)),
1090+
args=list(itertools.chain([ast.NameConstant(has_vargs)],
1091+
_unwrap_nodes(exprs))),
10851092
keywords=[]))
10861093
except IndexError:
10871094
raise CompilerException("Attempting to recur without recur point") from None

basilisp/core/__init__.lpy

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@
6666
(fn string? [o]
6767
(instance? builtins/str o)))
6868

69+
(def
70+
^{:doc "Return true if obj is a symbol."}
71+
symbol?
72+
(fn symbol? [o]
73+
(instance? basilisp.lang.symbol/Symbol o)))
74+
75+
(def
76+
^{:doc "Return true if obj is a keyword."}
77+
keyword?
78+
(fn keyword? [o]
79+
(instance? basilisp.lang.keyword/Keyword o)))
80+
6981
(def
7082
^{:doc "Return true if o is a list."}
7183
list?
@@ -134,25 +146,35 @@
134146

135147
(def
136148
^{:macro true
137-
:doc "Define a new function."}
149+
:doc "Define a new function with an optional docstring."}
138150
defn
139151
(fn defn [&form name & body]
140-
(let [body (concat body)
141-
doc (if (string? (first body))
142-
(first body)
143-
nil)
144-
fname (if doc
145-
(with-meta name {:doc doc})
146-
name)
147-
body (if doc
148-
(rest body)
149-
body)
150-
args (if (vector? (first body))
151-
(first body)
152-
nil) ;; Should throw here!
153-
body (rest body)]
152+
(if (symbol? name)
153+
nil ;; Do nothing!
154+
(throw (ex-info "First argument to defn must be a symbol"
155+
{:found name :type (builtins/type name)})))
156+
(let [body (concat body)
157+
doc (if (string? (first body))
158+
(first body)
159+
nil)
160+
fname (if doc
161+
(with-meta name {:doc doc})
162+
name)
163+
body (if doc
164+
(rest body)
165+
body)
166+
multi? (list? (first body))
167+
body (if multi?
168+
body
169+
(cons
170+
(if (vector? (first body))
171+
(first body)
172+
(throw
173+
(ex-info "Expected an argument vector"
174+
{:found (first body)})))
175+
(rest body)))]
154176
`(def ~fname
155-
(fn* ~fname ~args
177+
(fn* ~fname
156178
~@body)))))
157179

158180
(defn nth
@@ -185,3 +207,12 @@
185207
(if (seq (rest s))
186208
(recur (rest s))
187209
(first s)))
210+
211+
(defn +
212+
"Sum the arguments together. If no arguments given, returns 0."
213+
([] 0)
214+
([x] x)
215+
([x & args]
216+
(if (seq (rest args))
217+
(recur (operator/add x (first args)) (rest args))
218+
(operator/add x (first args)))))

basilisp/lang/runtime.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import itertools
33
import threading
44
import types
5-
from typing import Optional, List, Dict
5+
from typing import Optional, Dict, Tuple
66

77
from functional import seq
88
from pyrsistent import pmap, PMap, PSet, pset
@@ -506,15 +506,29 @@ def _collect_args(args) -> lseq.Seq:
506506

507507

508508
class _TrampolineArgs:
509-
__slots__ = ('_args', '_kwargs')
509+
__slots__ = ('_has_varargs', '_args', '_kwargs')
510510

511-
def __init__(self, *args, **kwargs):
511+
def __init__(self, has_varargs: bool, *args, **kwargs) -> None:
512+
self._has_varargs = has_varargs
512513
self._args = args
513514
self._kwargs = kwargs
514515

515516
@property
516-
def args(self) -> List:
517-
return self._args
517+
def args(self) -> Tuple:
518+
"""Return the arguments for a trampolined function. If the function
519+
that is being trampolined has varargs, unroll the final argument if
520+
it is a sequence."""
521+
if not self._has_varargs:
522+
return self._args
523+
524+
try:
525+
final = self._args[-1]
526+
if isinstance(final, lseq.Seq):
527+
inits = self._args[:-1]
528+
return tuple(itertools.chain(inits, final))
529+
return self._args
530+
except IndexError:
531+
return ()
518532

519533
@property
520534
def kwargs(self) -> Dict:

tests/compiler_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,45 @@ def test_recur(ns_var: Var):
414414
assert "3ba" == lcompile("(rev-str \"a\" :b 3)")
415415

416416

417+
def test_recur_arity(ns_var: Var):
418+
# Single arity function
419+
code = """
420+
(def ++
421+
(fn ++ [x & args]
422+
(if (seq (rest args))
423+
(recur (operator/add x (first args)) (rest args))
424+
(operator/add x (first args)))))
425+
"""
426+
427+
lcompile(code)
428+
429+
assert 3 == lcompile("(++ 1 2)")
430+
assert 6 == lcompile("(++ 1 2 3)")
431+
assert 10 == lcompile("(++ 1 2 3 4)")
432+
assert 15 == lcompile("(++ 1 2 3 4 5)")
433+
434+
# Multi-arity function
435+
code = """
436+
(def +++
437+
(fn +++
438+
([] 0)
439+
([x] x)
440+
([x & args]
441+
(if (seq (rest args))
442+
(recur (operator/add x (first args)) (rest args))
443+
(operator/add x (first args))))))
444+
"""
445+
446+
lcompile(code)
447+
448+
assert 0 == lcompile("(+++)")
449+
assert 1 == lcompile("(+++ 1)")
450+
assert 3 == lcompile("(+++ 1 2)")
451+
assert 6 == lcompile("(+++ 1 2 3)")
452+
assert 10 == lcompile("(+++ 1 2 3 4)")
453+
assert 15 == lcompile("(+++ 1 2 3 4 5)")
454+
455+
417456
def test_disallow_recur_in_special_forms(ns_var: Var):
418457
with pytest.raises(compiler.CompilerException):
419458
lcompile("(fn [a] (def b (recur \"a\")))")

tests/runtime_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,23 @@ def test_conj():
201201

202202
with pytest.raises(TypeError):
203203
runtime.conj("b", 1, "a")
204+
205+
206+
def test_trampoline_args():
207+
args = runtime._TrampolineArgs(True)
208+
assert () == args.args
209+
210+
args = runtime._TrampolineArgs(False, llist.l(2, 3, 4))
211+
assert (llist.l(2, 3, 4),) == args.args
212+
213+
args = runtime._TrampolineArgs(True, llist.l(2, 3, 4))
214+
assert (2, 3, 4) == args.args
215+
216+
args = runtime._TrampolineArgs(False, 1, 2, 3, llist.l(4, 5, 6))
217+
assert (1, 2, 3, llist.l(4, 5, 6)) == args.args
218+
219+
args = runtime._TrampolineArgs(True, 1, 2, 3, llist.l(4, 5, 6))
220+
assert (1, 2, 3, 4, 5, 6) == args.args
221+
222+
args = runtime._TrampolineArgs(True, 1, llist.l(2, 3, 4), 5, 6)
223+
assert (1, llist.l(2, 3, 4), 5, 6) == args.args

0 commit comments

Comments
 (0)