Skip to content

Commit 1d23cbc

Browse files
authored
Fix transient macro namespace resolution (#438)
* Fix transient namespace resolution * Formatting * Add a test for referred names as well * Additional test cases (which are currently broken) * Plz fix * Resolve some remaining issues * Satisfy PyLint
1 parent 1a35edc commit 1d23cbc

File tree

6 files changed

+253
-34
lines changed

6 files changed

+253
-34
lines changed

src/basilisp/core/template.lpy

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
(ns basilisp.core.template)
1+
(ns basilisp.core.template
2+
(:require
3+
[basilisp.walk :refer [postwalk-replace]]))
24

35
(defmacro do-template
46
"Given a template expression expr and bindings, produce a do expression with
@@ -23,6 +25,6 @@
2325
(as-> arg-group $
2426
(interleave argv $)
2527
(apply hash-map $)
26-
(replace $ expr)))]
28+
(postwalk-replace $ expr)))]
2729
`(do
2830
~@(map template-expr arg-groups))) )

src/basilisp/lang/compiler/analyzer.py

Lines changed: 137 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ class AnalyzerContext:
281281
"_filename",
282282
"_func_ctx",
283283
"_is_quoted",
284+
"_macro_ns",
284285
"_opts",
285286
"_recur_points",
286287
"_should_macroexpand",
@@ -298,6 +299,7 @@ def __init__(
298299
self._filename = Maybe(filename).or_else_get(DEFAULT_COMPILER_FILE_PATH)
299300
self._func_ctx: Deque[bool] = collections.deque([])
300301
self._is_quoted: Deque[bool] = collections.deque([])
302+
self._macro_ns: Deque[Optional[runtime.Namespace]] = collections.deque([])
301303
self._opts = (
302304
Maybe(opts).map(lmap.map).or_else_get(lmap.Map.empty()) # type: ignore
303305
)
@@ -348,6 +350,35 @@ def quoted(self):
348350
yield
349351
self._is_quoted.pop()
350352

353+
@property
354+
def current_macro_ns(self) -> Optional[runtime.Namespace]:
355+
"""Return the current transient namespace available during macroexpansion.
356+
357+
If None, the analyzer should only use the current namespace for symbol
358+
resolution."""
359+
try:
360+
return self._macro_ns[-1]
361+
except IndexError:
362+
return None
363+
364+
@contextlib.contextmanager
365+
def macro_ns(self, ns: Optional[runtime.Namespace]):
366+
"""Set the transient namespace which is available to the analyer during a
367+
macroexpansion phase.
368+
369+
If set to None, prohibit the analyzer from using another namespace for symbol
370+
resolution.
371+
372+
During macroexpansion, new forms referenced from the macro namespace would
373+
be unavailable to the namespace containing the original macro invocation.
374+
The macro namespace is a temporary override pointing to the namespace of the
375+
macro definition which can be used to resolve these transient references."""
376+
self._macro_ns.append(ns)
377+
try:
378+
yield
379+
finally:
380+
self._macro_ns.pop()
381+
351382
@property
352383
def should_allow_unresolved_symbols(self) -> bool:
353384
"""If True, the analyzer will allow unresolved symbols. This is primarily
@@ -1639,7 +1670,10 @@ def _invoke_ast(ctx: AnalyzerContext, form: Union[llist.List, ISeq]) -> Node:
16391670
try:
16401671
macro_env = ctx.symbol_table.as_env_map()
16411672
expanded = fn.var.value(macro_env, form, *form.rest)
1642-
expanded_ast = _analyze_form(ctx, expanded)
1673+
with ctx.macro_ns(
1674+
fn.var.ns if fn.var.ns is not ctx.current_ns else None
1675+
):
1676+
expanded_ast = _analyze_form(ctx, expanded)
16431677

16441678
# Verify that macroexpanded code also does not have any
16451679
# non-tail recur forms
@@ -2085,34 +2119,60 @@ def _resolve_nested_symbol(ctx: AnalyzerContext, form: sym.Symbol) -> HostField:
20852119
)
20862120

20872121

2088-
def __resolve_namespaced_symbol( # pylint: disable=too-many-branches
2089-
ctx: AnalyzerContext, form: sym.Symbol
2090-
) -> Union[Const, HostField, MaybeClass, MaybeHostForm, VarRef]:
2091-
"""Resolve a namespaced symbol into a Python name or Basilisp Var."""
2122+
def __fuzzy_resolve_namespace_reference(
2123+
ctx: AnalyzerContext, which_ns: runtime.Namespace, form: sym.Symbol
2124+
) -> Optional[VarRef]:
2125+
"""Resolve a symbol within `which_ns` based on any namespaces required or otherwise
2126+
referenced within `which_ns` (e.g. by a :refer).
2127+
2128+
When a required or resolved symbol is read by the reader in the context of a syntax
2129+
quote, the reader will fully resolve the symbol, so a symbol like `set/union` would be
2130+
expanded to `basilisp.set/union`. However, the namespace still does not maintain a
2131+
direct mapping of the symbol `basilisp.set` to the namespace it names, since the
2132+
namespace was required as `[basilisp.set :as set]`.
2133+
2134+
During macroexpansion, the Analyzer needs to resolve these transitive requirements,
2135+
so we 'fuzzy' resolve against any namespaces known to the current macro namespace."""
20922136
assert form.ns is not None
2137+
ns_name = form.ns
2138+
2139+
def resolve_ns_reference(
2140+
ns_map: Mapping[str, runtime.Namespace]
2141+
) -> Optional[VarRef]:
2142+
match: Optional[runtime.Namespace] = ns_map.get(ns_name)
2143+
if match is not None:
2144+
v = match.find(sym.symbol(form.name))
2145+
if v is not None:
2146+
return VarRef(form=form, var=v, env=ctx.get_node_env())
2147+
return None
20932148

2094-
if form.ns == ctx.current_ns.name:
2095-
v = ctx.current_ns.find(sym.symbol(form.name))
2096-
if v is not None:
2097-
return VarRef(form=form, var=v, env=ctx.get_node_env())
2098-
elif form.ns == _BUILTINS_NS:
2099-
class_ = munge(form.name, allow_builtins=True)
2100-
target = getattr(builtins, class_, None)
2101-
if target is None:
2102-
raise AnalyzerException(
2103-
f"cannot resolve builtin function '{class_}'", form=form
2104-
)
2105-
return MaybeClass(
2106-
form=form, class_=class_, target=target, env=ctx.get_node_env()
2107-
)
2149+
# Try to match a required namespace
2150+
required_namespaces = {ns.name: ns for ns in which_ns.aliases.values()}
2151+
match = resolve_ns_reference(required_namespaces)
2152+
if match is not None:
2153+
return match
21082154

2109-
if "." in form.name and form.name != _DOUBLE_DOT_MACRO_NAME:
2110-
raise AnalyzerException(
2111-
"symbol names may not contain the '.' operator", form=form
2112-
)
2155+
# Try to match a referred namespace
2156+
referred_namespaces = {
2157+
ns.name: ns for ns in {var.ns for var in which_ns.refers.values()}
2158+
}
2159+
return resolve_ns_reference(referred_namespaces)
2160+
2161+
2162+
def __resolve_namespaced_symbol_in_ns( # pylint: disable=too-many-branches
2163+
ctx: AnalyzerContext,
2164+
which_ns: runtime.Namespace,
2165+
form: sym.Symbol,
2166+
allow_fuzzy_macroexpansion_matching: bool = False,
2167+
) -> Optional[Union[MaybeHostForm, VarRef]]:
2168+
"""Resolve the symbol `form` in the context of the Namespace `which_ns`. If
2169+
`allow_fuzzy_macroexpansion_matching` is True and no match is made on existing
2170+
imports, import aliases, or namespace aliases, then attempt to match the
2171+
namespace portion"""
2172+
assert form.ns is not None
21132173

21142174
ns_sym = sym.symbol(form.ns)
2115-
if ns_sym in ctx.current_ns.imports or ns_sym in ctx.current_ns.import_aliases:
2175+
if ns_sym in which_ns.imports or ns_sym in which_ns.import_aliases:
21162176
# We still import Basilisp code, so we'll want to make sure
21172177
# that the symbol isn't referring to a Basilisp Var first
21182178
v = Var.find(form)
@@ -2123,8 +2183,8 @@ def __resolve_namespaced_symbol( # pylint: disable=too-many-branches
21232183
# We don't need this for actually generating the link later, but
21242184
# we _do_ need it for fetching a reference to the module to check
21252185
# for membership.
2126-
if ns_sym in ctx.current_ns.import_aliases:
2127-
ns = ctx.current_ns.import_aliases[ns_sym]
2186+
if ns_sym in which_ns.import_aliases:
2187+
ns = which_ns.import_aliases[ns_sym]
21282188
assert ns is not None
21292189
ns_name = ns.name
21302190
else:
@@ -2161,16 +2221,58 @@ def __resolve_namespaced_symbol( # pylint: disable=too-many-branches
21612221
target=vars(ns_module)[safe_name],
21622222
env=ctx.get_node_env(),
21632223
)
2164-
elif ns_sym in ctx.current_ns.aliases:
2165-
aliased_ns: runtime.Namespace = ctx.current_ns.aliases[ns_sym]
2224+
elif ns_sym in which_ns.aliases:
2225+
aliased_ns: runtime.Namespace = which_ns.aliases[ns_sym]
21662226
v = Var.find(sym.symbol(form.name, ns=aliased_ns.name))
21672227
if v is None:
21682228
raise AnalyzerException(
21692229
f"unable to resolve symbol '{sym.symbol(form.name, ns_sym.name)}' in this context",
21702230
form=form,
21712231
)
21722232
return VarRef(form=form, var=v, env=ctx.get_node_env())
2173-
elif "." in form.ns:
2233+
elif allow_fuzzy_macroexpansion_matching:
2234+
return __fuzzy_resolve_namespace_reference(ctx, which_ns, form)
2235+
2236+
return None
2237+
2238+
2239+
def __resolve_namespaced_symbol( # pylint: disable=too-many-branches
2240+
ctx: AnalyzerContext, form: sym.Symbol
2241+
) -> Union[Const, HostField, MaybeClass, MaybeHostForm, VarRef]:
2242+
"""Resolve a namespaced symbol into a Python name or Basilisp Var."""
2243+
assert form.ns is not None
2244+
2245+
if form.ns == ctx.current_ns.name:
2246+
v = ctx.current_ns.find(sym.symbol(form.name))
2247+
if v is not None:
2248+
return VarRef(form=form, var=v, env=ctx.get_node_env())
2249+
elif form.ns == _BUILTINS_NS:
2250+
class_ = munge(form.name, allow_builtins=True)
2251+
target = getattr(builtins, class_, None)
2252+
if target is None:
2253+
raise AnalyzerException(
2254+
f"cannot resolve builtin function '{class_}'", form=form
2255+
)
2256+
return MaybeClass(
2257+
form=form, class_=class_, target=target, env=ctx.get_node_env()
2258+
)
2259+
2260+
if "." in form.name and form.name != _DOUBLE_DOT_MACRO_NAME:
2261+
raise AnalyzerException(
2262+
"symbol names may not contain the '.' operator", form=form
2263+
)
2264+
2265+
resolved = __resolve_namespaced_symbol_in_ns(ctx, ctx.current_ns, form)
2266+
if resolved is not None:
2267+
return resolved
2268+
elif ctx.current_macro_ns is not None:
2269+
resolved = __resolve_namespaced_symbol_in_ns(
2270+
ctx, ctx.current_macro_ns, form, allow_fuzzy_macroexpansion_matching=True
2271+
)
2272+
if resolved is not None:
2273+
return resolved
2274+
2275+
if "." in form.ns:
21742276
return _resolve_nested_symbol(ctx, form)
21752277
elif ctx.should_allow_unresolved_symbols:
21762278
return _const_node(ctx, form)
@@ -2192,6 +2294,12 @@ def __resolve_bare_symbol(
21922294
if v is not None:
21932295
return VarRef(form=form, var=v, env=ctx.get_node_env())
21942296

2297+
# Look up the symbol in the current macro namespace, if one
2298+
if ctx.current_macro_ns is not None:
2299+
v = ctx.current_macro_ns.find(form)
2300+
if v is not None:
2301+
return VarRef(form=form, var=v, env=ctx.get_node_env())
2302+
21952303
if "." in form.name:
21962304
raise AnalyzerException(
21972305
"symbol names may not contain the '.' operator", form=form

src/basilisp/lang/interfaces.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ def __eq__(self, other):
223223
return False
224224
return True
225225

226+
def __hash__(self):
227+
return hash(tuple(self))
228+
226229
def __iter__(self):
227230
o = self
228231
if o:

src/basilisp/test.lpy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
(:import
33
inspect)
44
(:require
5-
[basilisp.core.template :refer [do-template]]))
5+
[basilisp.core.template :as template]))
66

77
(def ^:private collected-tests
88
(atom []))
@@ -115,7 +115,7 @@
115115
(defmacro are
116116
"Assert that expr is true. Must appear inside of a deftest form."
117117
[argv expr & args]
118-
`(do-template ~argv (is ~expr) ~@args))
118+
`(template/do-template ~argv (is ~expr) ~@args))
119119

120120
(defmacro testing
121121
"Wrapper for test cases to provide additional messaging and context

src/basilisp/testrunner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def runtest(self):
153153
if runtime.to_seq(failures):
154154
raise TestFailuresInfo("Test failures", lmap.map(results))
155155

156-
def repr_failure(self, excinfo):
156+
def repr_failure(self, excinfo, style=None): # pylint: disable=unused-argument
157157
"""Representation function called when self.runtest() raises an
158158
exception."""
159159
if isinstance(excinfo.value, TestFailuresInfo):

tests/basilisp/compiler_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2877,6 +2877,112 @@ def test_aliased_macro_symbol_resolution(self, ns: runtime.Namespace):
28772877
finally:
28782878
runtime.Namespace.remove(other_ns_name)
28792879

2880+
def test_cross_ns_macro_symbol_resolution(self, ns: runtime.Namespace):
2881+
"""Ensure that a macro symbol, `a`, delegating to another macro, named
2882+
by the symbol `b`, in a namespace directly required by `a`'s namespace
2883+
(and which will not be required by downstream namespaces) is still
2884+
properly resolved when used by the final consumer."""
2885+
current_ns: runtime.Namespace = ns
2886+
other_ns_name = sym.symbol("other.ns")
2887+
third_ns_name = sym.symbol("third.ns")
2888+
try:
2889+
other_ns = runtime.Namespace.get_or_create(other_ns_name)
2890+
current_ns.add_alias(other_ns_name, other_ns)
2891+
2892+
third_ns = runtime.Namespace.get_or_create(third_ns_name)
2893+
other_ns.add_alias(third_ns_name, third_ns)
2894+
2895+
with runtime.ns_bindings(third_ns_name.name):
2896+
lcompile(
2897+
"(def ^:macro t (fn* [&env &form v] `(name ~v)))",
2898+
resolver=runtime.resolve_alias,
2899+
)
2900+
2901+
with runtime.ns_bindings(other_ns_name.name):
2902+
lcompile(
2903+
"(def ^:macro o (fn* [&env &form v] `(third.ns/t ~v)))",
2904+
resolver=runtime.resolve_alias,
2905+
)
2906+
2907+
with runtime.ns_bindings(current_ns.name):
2908+
assert "z" == lcompile(
2909+
"(other.ns/o :z)", resolver=runtime.resolve_alias
2910+
)
2911+
finally:
2912+
runtime.Namespace.remove(other_ns_name)
2913+
runtime.Namespace.remove(third_ns_name)
2914+
2915+
def test_cross_ns_macro_symbol_resolution_with_aliases(self, ns: runtime.Namespace):
2916+
"""Ensure that `a` macro symbol, a, delegating to another macro, named
2917+
by the symbol `b`, which is referenced by `a`'s namespace (and which will
2918+
not be referred by downstream namespaces) is still properly resolved
2919+
when used by the final consumer."""
2920+
current_ns: runtime.Namespace = ns
2921+
other_ns_name = sym.symbol("other.ns")
2922+
third_ns_name = sym.symbol("third.ns")
2923+
try:
2924+
other_ns = runtime.Namespace.get_or_create(other_ns_name)
2925+
current_ns.add_alias(other_ns_name, other_ns)
2926+
2927+
third_ns = runtime.Namespace.get_or_create(third_ns_name)
2928+
other_ns.add_alias(sym.symbol("third"), third_ns)
2929+
2930+
with runtime.ns_bindings(third_ns_name.name):
2931+
lcompile(
2932+
"(def ^:macro t (fn* [&env &form v] `(name ~v)))",
2933+
resolver=runtime.resolve_alias,
2934+
)
2935+
2936+
with runtime.ns_bindings(other_ns_name.name):
2937+
lcompile(
2938+
"(def ^:macro o (fn* [&env &form v] `(third/t ~v)))",
2939+
resolver=runtime.resolve_alias,
2940+
)
2941+
2942+
with runtime.ns_bindings(current_ns.name):
2943+
assert "z" == lcompile(
2944+
"(other.ns/o :z)", resolver=runtime.resolve_alias
2945+
)
2946+
finally:
2947+
runtime.Namespace.remove(other_ns_name)
2948+
runtime.Namespace.remove(third_ns_name)
2949+
2950+
def test_cross_ns_macro_symbol_resolution_with_refers(self, ns: runtime.Namespace):
2951+
"""Ensure that a macro symbol, `a`, delegating to another macro, named
2952+
by the symbol `b`, which is referred by `a`'s namespace (and which will
2953+
not be referred by downstream namespaces) is still properly resolved
2954+
when used by the final consumer."""
2955+
current_ns: runtime.Namespace = ns
2956+
other_ns_name = sym.symbol("other.ns")
2957+
third_ns_name = sym.symbol("third.ns")
2958+
try:
2959+
other_ns = runtime.Namespace.get_or_create(other_ns_name)
2960+
current_ns.add_alias(other_ns_name, other_ns)
2961+
2962+
third_ns = runtime.Namespace.get_or_create(third_ns_name)
2963+
2964+
with runtime.ns_bindings(third_ns_name.name):
2965+
lcompile(
2966+
"(def ^:macro t (fn* [&env &form v] `(name ~v)))",
2967+
resolver=runtime.resolve_alias,
2968+
)
2969+
2970+
other_ns.add_refer(sym.symbol("t"), third_ns.find(sym.symbol("t")))
2971+
2972+
with runtime.ns_bindings(other_ns_name.name):
2973+
lcompile(
2974+
"(def ^:macro o (fn* [&env &form v] `(t ~v)))",
2975+
resolver=runtime.resolve_alias,
2976+
)
2977+
2978+
with runtime.ns_bindings(current_ns.name):
2979+
assert "z" == lcompile(
2980+
"(other.ns/o :z)", resolver=runtime.resolve_alias
2981+
)
2982+
finally:
2983+
runtime.Namespace.remove(other_ns_name)
2984+
runtime.Namespace.remove(third_ns_name)
2985+
28802986

28812987
class TestWarnOnVarIndirection:
28822988
@pytest.fixture

0 commit comments

Comments
 (0)