Skip to content

Commit 77b1871

Browse files
authored
Resolve macro symbols using all possible heuristics (#183)
* Resolve macro symbols using all possible heuristics * Set current namespace out-of-band using runtime function
1 parent b403d59 commit 77b1871

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

basilisp/compiler.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,32 @@ def _special_form_ast(ctx: CompilerContext,
13221322
raise CompilerException("Special form identified, but not handled") from None
13231323

13241324

1325+
def _resolve_macro_sym(ctx: CompilerContext, form: sym.Symbol) -> Optional[Var]:
1326+
"""Determine if a Basilisp symbol refers to a macro and, if so, return the
1327+
Var it points to.
1328+
1329+
If the symbol cannot be resolved or does not refer to a macro, then this
1330+
function will return None. _sym_ast will generate the AST for a standard
1331+
function call."""
1332+
if form.ns is not None:
1333+
if form.ns == _BUILTINS_NS:
1334+
return None
1335+
elif form.ns == ctx.current_ns.name:
1336+
return ctx.current_ns.find(sym.symbol(form.name))
1337+
ns_sym = sym.symbol(form.ns)
1338+
if ns_sym in ctx.current_ns.imports:
1339+
# We still import Basilisp code, so we'll want to check if
1340+
# the symbol is referring to a Basilisp Var
1341+
return Var.find(form)
1342+
elif ns_sym in ctx.current_ns.aliases:
1343+
aliased_ns = ctx.current_ns.get_alias(ns_sym)
1344+
if aliased_ns:
1345+
return Var.find(sym.symbol(form.name, ns=aliased_ns.name))
1346+
return None
1347+
1348+
return ctx.current_ns.find(form)
1349+
1350+
13251351
def _list_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
13261352
"""Generate a stream of Python AST nodes for a source code list.
13271353
@@ -1359,12 +1385,7 @@ def _list_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
13591385

13601386
# Macros are immediately evaluated so the modified form can be compiled
13611387
if isinstance(first, sym.Symbol):
1362-
if first.ns is not None:
1363-
v = Var.find(first)
1364-
else:
1365-
ns_sym = sym.symbol(first.name, ns=ctx.current_ns.name)
1366-
v = Var.find(ns_sym)
1367-
1388+
v = _resolve_macro_sym(ctx, first)
13681389
if v is not None and _is_macro(v):
13691390
try:
13701391
# Call the macro as (f &form & rest)

tests/compiler_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,24 @@ def test_unquote_splicing(ns_var: Var, resolver: reader.Resolver):
693693
assert llist.l(llist.l(reader._UNQUOTE_SPLICING, 53233)) == lcompile("'(~@53233)")
694694

695695

696+
def test_aliased_macro_symbol_resolution(ns_var: Var):
697+
current_ns: runtime.Namespace = ns_var.value
698+
other_ns_name = sym.symbol('other.ns')
699+
try:
700+
other_ns = runtime.Namespace.get_or_create(other_ns_name)
701+
current_ns.add_alias(other_ns_name, other_ns)
702+
current_ns.add_alias(sym.symbol('other'), other_ns)
703+
704+
runtime.set_current_ns(other_ns_name.name)
705+
lcompile("(def ^:macro m (fn* [&form v] v))")
706+
707+
runtime.set_current_ns(current_ns.name)
708+
assert kw.keyword("z") == lcompile("(other.ns/m :z)")
709+
assert kw.keyword("a") == lcompile("(other/m :a)")
710+
finally:
711+
runtime.Namespace.remove(other_ns_name)
712+
713+
696714
def test_var(ns_var: Var):
697715
code = """
698716
(def some-var "a value")

0 commit comments

Comments
 (0)