Skip to content

Commit 3bbc9fc

Browse files
authored
Support functions of multiple arities in compiler (#80)
* Support functions of multiple arities in compiler * Refactoring function AST code slightly * Test new function arities * Remove leading space * Remove unused function def
1 parent e041fc3 commit 3bbc9fc

File tree

2 files changed

+249
-12
lines changed

2 files changed

+249
-12
lines changed

basilisp/compiler.py

Lines changed: 181 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
_DO_PREFIX = 'lisp_do'
4141
_FN_PREFIX = 'lisp_fn'
4242
_IF_PREFIX = 'lisp_if'
43+
_MULTI_ARITY_ARG_NAME = 'multi_arity_args'
4344
_THROW_PREFIX = 'lisp_throw'
4445
_TRY_PREFIX = 'lisp_try'
4546
_NS_VAR = '__NS'
@@ -479,9 +480,11 @@ def _do_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
479480
func=ast.Name(id=do_fn_name, ctx=ast.Load()), args=[], keywords=[]))
480481

481482

483+
FunctionDefDetails = Tuple[List[ast.arg], ASTStream, Optional[ast.arg]]
484+
485+
482486
def _fn_args_body(ctx: CompilerContext, arg_vec: vec.Vector,
483-
body_exprs: llist.List
484-
) -> Tuple[List[ast.arg], ASTStream, Optional[ast.arg]]:
487+
body_exprs: llist.List) -> FunctionDefDetails:
485488
"""Generate the Python AST Nodes for a Lisp function argument vector
486489
and body expressions. Return a tuple of arg nodes and body AST nodes."""
487490
st = ctx.symbol_table
@@ -522,22 +525,189 @@ def _fn_args_body(ctx: CompilerContext, arg_vec: vec.Vector,
522525
return args, cast(ASTStream, body), vargs
523526

524527

525-
def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
526-
"""Generate a Python AST Node for an anonymous function."""
527-
assert form.first == _FN
528-
has_name = isinstance(form[1], sym.Symbol)
529-
name = genname("__" + (munge(form[1].name) if has_name else _FN_PREFIX))
528+
FunctionArityDetails = Tuple[int, bool, llist.List]
529+
530+
531+
def _fn_arities(form: llist.List) -> Iterable[FunctionArityDetails]:
532+
"""Return the arities of a function definition and some additional details about
533+
the argument vector. Verify that all arities are compatible. In particular, this
534+
function will throw a CompilerException if any of the following are true:
535+
- two function definitions have the same number of arguments
536+
- two function definitions have a rest parameter
537+
- any function definition has the same number of arguments as a definition
538+
with a rest parameter
539+
540+
Given a function such as this:
541+
542+
(fn a
543+
([] :a)
544+
([a] a))
545+
546+
Returns a generator yielding: '(([] :a) ([a] a))
547+
548+
Single arity functions yield the rest:
549+
550+
(fn a [] :a) ;=> '(([] :a))"""
551+
if not all(map(lambda f: isinstance(f, llist.List) and isinstance(f.first, vec.Vector), form)):
552+
assert isinstance(form.first, vec.Vector)
553+
yield len(form.first), False, form
554+
return
530555

531-
arg_idx = 1 + int(has_name)
532-
body_idx = 2 + int(has_name)
556+
arg_counts: Dict[int, llist.List] = {}
557+
has_vargs = False
558+
vargs_len = None
559+
for arity in form:
560+
# Verify each arity is unique
561+
arg_count = len(arity.first)
562+
if arg_count in arg_counts:
563+
raise CompilerException("Each arity in multi-arity fn must be unique",
564+
[arity, arg_counts[arg_count]])
565+
566+
# Verify that only one arity contains a rest-param
567+
is_rest = False
568+
for arg in arity.first:
569+
if arg == _AMPERSAND:
570+
if has_vargs:
571+
raise CompilerException("Only one arity in multi-arity fn may have rest param")
572+
is_rest = True
573+
has_vargs = True
574+
arg_count -= 1
575+
vargs_len = arg_count
576+
577+
# Verify that arities do not exceed rest-param arity
578+
if vargs_len is not None and any([c >= vargs_len for c in arg_counts.keys()]):
579+
raise CompilerException("No arity in multi-arity fn may exceed the rest param arity")
580+
581+
# Put this in last so it does not conflict with the above checks
582+
arg_counts[arg_count] = arity
583+
584+
yield arg_count, is_rest, arity
585+
586+
587+
def _compose_ifs(if_stmts: List[Dict[str, ast.AST]], orelse: List[ast.AST] = None) -> ast.If:
588+
"""Compose a series of If statements into nested elifs, with
589+
an optional terminating else."""
590+
first = if_stmts[0]
591+
try:
592+
rest = if_stmts[1:]
593+
return ast.If(test=first["test"],
594+
body=[first["body"]],
595+
orelse=[_compose_ifs(rest, orelse=orelse)])
596+
except IndexError:
597+
return ast.If(test=first["test"],
598+
body=[first["body"]],
599+
orelse=Maybe(orelse).or_else_get([]))
533600

534-
assert isinstance(form[arg_idx], vec.Vector)
535601

602+
def _single_arity_fn_ast(ctx: CompilerContext, name: str, fndef: llist.List) -> ASTStream:
603+
"""Generate Python AST nodes for a single-arity function."""
536604
with ctx.new_symbol_table(name):
537-
args, body, vargs = _fn_args_body(ctx, form[arg_idx], form[body_idx:])
605+
args, body, vargs = _fn_args_body(ctx, fndef.first, fndef.rest)
538606

539607
yield _dependency(_expressionize(body, name, args=args, vargs=vargs))
540608
yield _node(ast.Name(id=name, ctx=ast.Load()))
609+
return
610+
611+
612+
def _multi_arity_fn_ast(ctx: CompilerContext, name: str, arities: List[FunctionArityDetails]) -> ASTStream:
613+
"""Generate Python AST nodes for multi-arity Basilisp function definitions.
614+
615+
For example, a multi-arity function like this:
616+
617+
(def f
618+
(fn f
619+
([] (print "No args"))
620+
([arg]
621+
(print arg))
622+
([arg & rest]
623+
(print (concat [arg] rest)))))
624+
625+
Would yield a function definition in Python code like this:
626+
627+
def __f_68__arity0():
628+
return print_('No args')
629+
630+
631+
def __f_68__arity1(arg_69):
632+
return print_(arg_69)
633+
634+
635+
def __f_68__arity_rest(arg_70, *rest_71):
636+
rest_72 = runtime._collect_args(rest_71)
637+
return print_(concat(vec.vector([arg_70], meta=None), rest_72))
638+
639+
640+
def __f_68(*multi_arity_args):
641+
if len(multi_arity_args) == 0:
642+
return __f_68__arity0(*multi_arity_args)
643+
elif len(multi_arity_args) == 1:
644+
return __f_68__arity1(*multi_arity_args)
645+
elif len(multi_arity_args) >= 2:
646+
return __f_68__arity2(*multi_arity_args)
647+
648+
649+
f = __f_68"""
650+
if_stmts: List[Dict[str, ast.AST]] = []
651+
multi_arity_args_arg = _load_attr(_MULTI_ARITY_ARG_NAME)
652+
has_rest = False
653+
654+
for arg_count, is_rest, arity in arities:
655+
has_rest = any([has_rest, is_rest])
656+
arity_name = f"{name}__arity{'_rest' if is_rest else arg_count}"
657+
658+
with ctx.new_symbol_table(arity_name):
659+
args, body, vargs = _fn_args_body(ctx, arity.first, arity.rest)
660+
661+
yield _dependency(_expressionize(body, arity_name, args=args, vargs=vargs))
662+
compare_op = ast.GtE() if is_rest else ast.Eq()
663+
if_stmts.append(
664+
{"test": ast.Compare(left=ast.Call(func=_load_attr('len'),
665+
args=[multi_arity_args_arg],
666+
keywords=[]),
667+
ops=[compare_op],
668+
comparators=[ast.Num(arg_count)]),
669+
"body": ast.Return(value=ast.Call(func=_load_attr(arity_name),
670+
args=[ast.Starred(value=multi_arity_args_arg, ctx=ast.Load())],
671+
keywords=[]))})
672+
673+
assert len(if_stmts) == len(arities)
674+
675+
yield _dependency(ast.FunctionDef(
676+
name=name,
677+
args=ast.arguments(
678+
args=[],
679+
kwarg=None,
680+
vararg=ast.arg(arg=_MULTI_ARITY_ARG_NAME, annotation=None),
681+
kwonlyargs=[],
682+
defaults=[],
683+
kw_defaults=[]),
684+
body=[_compose_ifs(if_stmts),
685+
ast.Raise(exc=ast.Call(func=_load_attr('basilisp.lang.runtime.RuntimeException'),
686+
args=[ast.Str(f"Wrong number of args passed to function: {name}")],
687+
keywords=[]),
688+
cause=None)],
689+
decorator_list=[],
690+
returns=None))
691+
yield _node(ast.Name(id=name, ctx=ast.Load()))
692+
693+
694+
def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
695+
"""Generate a Python AST Nodes for function definitions."""
696+
assert form.first == _FN
697+
has_name = isinstance(form[1], sym.Symbol)
698+
name = genname("__" + (munge(form[1].name) if has_name else _FN_PREFIX))
699+
700+
rest_idx = 1 + int(has_name)
701+
arities = list(_fn_arities(form[rest_idx:]))
702+
if len(arities) == 0:
703+
raise CompilerException("Function def must have argument vector")
704+
elif len(arities) == 1:
705+
_, _, fndef = arities[0]
706+
yield from _single_arity_fn_ast(ctx, name, fndef)
707+
return
708+
else:
709+
yield from _multi_arity_fn_ast(ctx, name, arities)
710+
return
541711

542712

543713
def _if_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:

tests/compiler_test.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_do(ns_var: Var):
160160
assert lcompile("last-name") == "Vader"
161161

162162

163-
def test_fn(ns_var: Var):
163+
def test_single_arity_fn(ns_var: Var):
164164
code = """
165165
(def string-upper (fn* string-upper [s] (.upper s)))
166166
"""
@@ -171,6 +171,16 @@ def test_fn(ns_var: Var):
171171
assert callable(fvar.value)
172172
assert fvar.value("lower") == "LOWER"
173173

174+
code = """
175+
(def string-upper (fn* string-upper ([s] (.upper s))))
176+
"""
177+
ns_name = ns_var.value.name
178+
fvar = lcompile(code)
179+
assert fvar == Var.find_in_ns(
180+
sym.symbol(ns_name), sym.symbol('string-upper'))
181+
assert callable(fvar.value)
182+
assert fvar.value("lower") == "LOWER"
183+
174184
code = """
175185
(def string-lower #(.lower %))
176186
"""
@@ -182,6 +192,63 @@ def test_fn(ns_var: Var):
182192
assert fvar.value("UPPER") == "upper"
183193

184194

195+
def test_multi_arity_fn(ns_var: Var):
196+
with pytest.raises(compiler.CompilerException):
197+
lcompile('(fn f)')
198+
199+
with pytest.raises(compiler.CompilerException):
200+
lcompile("""
201+
(def f
202+
(fn* f
203+
([] :no-args)
204+
([] :also-no-args)))
205+
""")
206+
207+
with pytest.raises(compiler.CompilerException):
208+
lcompile("""
209+
(def f
210+
(fn* f
211+
([& args] (concat [:no-starter] args))
212+
([s & args] (concat [s] args))))
213+
""")
214+
215+
with pytest.raises(compiler.CompilerException):
216+
lcompile("""
217+
(def f
218+
(fn* f
219+
([s] (concat [s] :one-arg))
220+
([& args] (concat [:rest-params] args))))
221+
""")
222+
223+
code = """
224+
(def multi-fn
225+
(fn* multi-fn
226+
([] :no-args)
227+
([s] s)
228+
([s & args] (concat [s] args))))
229+
"""
230+
ns_name = ns_var.value.name
231+
fvar = lcompile(code)
232+
assert fvar == Var.find_in_ns(
233+
sym.symbol(ns_name), sym.symbol('multi-fn'))
234+
assert callable(fvar.value)
235+
assert fvar.value() == kw.keyword('no-args')
236+
assert fvar.value('STRING') == 'STRING'
237+
assert fvar.value(kw.keyword('first-arg'), 'second-arg', 3) == llist.l(kw.keyword('first-arg'), 'second-arg', 3)
238+
239+
with pytest.raises(runtime.RuntimeException):
240+
code = """
241+
(def angry-multi-fn
242+
(fn* angry-multi-fn
243+
([] :send-me-an-arg!)
244+
([i] i)
245+
([i j] (concat [i] [j]))))
246+
"""
247+
ns_name = ns_var.value.name
248+
fvar = lcompile(code)
249+
fvar.value(1, 2, 3)
250+
251+
185252
def test_fn_call(ns_var: Var):
186253
code = """
187254
(def string-upper (fn* [s] (.upper s)))

0 commit comments

Comments
 (0)