Skip to content

Commit a88eb51

Browse files
authored
Include the Function or Method context in the AST Node environment (#549)
This PR is a set of otherwise-independent changes I pulled out of #547. It includes the following changes: * The titular Function/Method context stored in the `NodeEnv` object of Lisp AST nodes * Changing the generated module-wide `__NS` variable to `_NS` because I learned that variables prefixed by double underscores are [always mangled](https://bugs.python.org/issue27793) in a class lexical context in Python regardless of the presence of a `global` statement. * Added a few tests for asserting the resolution of `def`'ed Basilisp Vars match the expected behavior from Clojure Fixes #548
1 parent c6eeea9 commit a88eb51

File tree

6 files changed

+141
-53
lines changed

6 files changed

+141
-53
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
* Added support for calling Python functions and methods with keyword arguments (#531)
1111
* Added support for Lisp functions being called with keyword arguments (#528)
1212
* Added support for multi-arity methods on `deftype`s (#534)
13+
* Added metadata about the function or method context of a Lisp AST node in the `NodeEnv` (#548)
1314

1415
### Fixed
1516
* Fixed a bug where the Basilisp AST nodes for return values of `deftype` members could be marked as _statements_ rather than _expressions_, resulting in an incorrect `nil` return (#523)

src/basilisp/lang/compiler/analyzer.py

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
Do,
8888
Fn,
8989
FnArity,
90+
FunctionContext,
9091
HostCall,
9192
HostField,
9293
If,
@@ -309,7 +310,7 @@ def __init__(
309310
) -> None:
310311
self._allow_unresolved_symbols = allow_unresolved_symbols
311312
self._filename = Maybe(filename).or_else_get(DEFAULT_COMPILER_FILE_PATH)
312-
self._func_ctx: Deque[bool] = collections.deque([])
313+
self._func_ctx: Deque[FunctionContext] = collections.deque([])
313314
self._is_quoted: Deque[bool] = collections.deque([])
314315
self._macro_ns: Deque[Optional[runtime.Namespace]] = collections.deque([])
315316
self._opts = (
@@ -406,34 +407,33 @@ def should_macroexpand(self) -> bool:
406407
return self._should_macroexpand
407408

408409
@property
409-
def is_async_ctx(self) -> bool:
410-
"""If True, the current node appears inside of an async function definition.
410+
def func_ctx(self) -> Optional[FunctionContext]:
411+
"""Return the current function or method context of the current node, if one.
412+
Return None otherwise.
413+
411414
It is possible that the current function is defined inside other functions,
412415
so this does not imply anything about the nesting level of the current node."""
413416
try:
414-
return self._func_ctx[-1] is True
417+
return self._func_ctx[-1]
415418
except IndexError:
416-
return False
419+
return None
417420

418421
@property
419-
def in_func_ctx(self) -> bool:
420-
"""If True, the current node appears inside of a function definition.
422+
def is_async_ctx(self) -> bool:
423+
"""Return True if the current node appears inside of an async function
424+
definition. Return False otherwise.
425+
421426
It is possible that the current function is defined inside other functions,
422427
so this does not imply anything about the nesting level of the current node."""
423-
try:
424-
self._func_ctx[-1]
425-
except IndexError:
426-
return False
427-
else:
428-
return True
428+
return self.func_ctx == FunctionContext.ASYNC_FUNCTION
429429

430430
@contextlib.contextmanager
431-
def new_func_ctx(self, is_async: bool = False):
432-
"""Context manager which can be used to set a function context for child
433-
nodes to examine. A new function context is pushed onto the stack each time
434-
the Analyzer finds a new function definition, so there may be many nested
435-
function contexts."""
436-
self._func_ctx.append(is_async)
431+
def new_func_ctx(self, context_type: FunctionContext):
432+
"""Context manager which can be used to set a function or method context for
433+
child nodes to examine. A new function context is pushed onto the stack each
434+
time the Analyzer finds a new function or method definition, so there may be
435+
many nested function contexts."""
436+
self._func_ctx.append(context_type)
437437
yield
438438
self._func_ctx.pop()
439439

@@ -578,8 +578,14 @@ def syntax_position(self) -> NodeSyntacticPosition:
578578
parent node."""
579579
return self._syntax_pos[-1]
580580

581-
def get_node_env(self, pos: Optional[NodeSyntacticPosition] = None):
582-
return NodeEnv(ns=self.current_ns, file=self.filename, pos=pos)
581+
def get_node_env(self, pos: Optional[NodeSyntacticPosition] = None) -> NodeEnv:
582+
"""Return the current Node environment.
583+
584+
If a synax position is given, it will be included in the environment.
585+
Otherwise, the position will be set to None."""
586+
return NodeEnv(
587+
ns=self.current_ns, file=self.filename, pos=pos, func_ctx=self.func_ctx
588+
)
583589

584590

585591
MetaGetter = Callable[[Union[IMeta, Var]], bool]
@@ -810,10 +816,11 @@ def _def_ast( # pylint: disable=too-many-branches,too-many-locals
810816
f"def names must be symbols, not {type(name)}", form=name
811817
)
812818

819+
children: vec.Vector[kw.Keyword]
813820
if nelems == 2:
814821
init = None
815822
doc = None
816-
children: vec.Vector[kw.Keyword] = vec.Vector.empty()
823+
children = vec.Vector.empty()
817824
elif nelems == 3:
818825
with ctx.expr_pos():
819826
init = _analyze_form(ctx, runtime.nth(form, 2))
@@ -895,7 +902,6 @@ def _def_ast( # pylint: disable=too-many-branches,too-many-locals
895902
var=var,
896903
init=init,
897904
doc=doc,
898-
in_func_ctx=ctx.in_func_ctx,
899905
children=children,
900906
env=def_node_env,
901907
)
@@ -1037,7 +1043,7 @@ def __deftype_classmethod(
10371043
has_vargs, fixed_arity, param_nodes = __deftype_method_param_bindings(
10381044
ctx, params
10391045
)
1040-
with ctx.expr_pos():
1046+
with ctx.new_func_ctx(FunctionContext.CLASSMETHOD), ctx.expr_pos():
10411047
stmts, ret = _body_ast(ctx, runtime.nthrest(form, 2))
10421048
method = DefTypeClassMethodArity(
10431049
form=form,
@@ -1097,7 +1103,7 @@ def __deftype_method(
10971103

10981104
loop_id = genname(method_name)
10991105
with ctx.new_recur_point(loop_id, param_nodes):
1100-
with ctx.expr_pos():
1106+
with ctx.new_func_ctx(FunctionContext.METHOD), ctx.expr_pos():
11011107
stmts, ret = _body_ast(ctx, runtime.nthrest(form, 2))
11021108
method = DefTypeMethodArity(
11031109
form=form,
@@ -1160,7 +1166,7 @@ def __deftype_property(
11601166

11611167
assert not has_vargs, "deftype* properties may not have arguments"
11621168

1163-
with ctx.expr_pos():
1169+
with ctx.new_func_ctx(FunctionContext.PROPERTY), ctx.expr_pos():
11641170
stmts, ret = _body_ast(ctx, runtime.nthrest(form, 2))
11651171
prop = DefTypeProperty(
11661172
form=form,
@@ -1192,7 +1198,7 @@ def __deftype_staticmethod(
11921198
"""Emit a node for a :staticmethod member of a deftype* form."""
11931199
with ctx.hide_parent_symbol_table(), ctx.new_symbol_table(method_name):
11941200
has_vargs, fixed_arity, param_nodes = __deftype_method_param_bindings(ctx, args)
1195-
with ctx.expr_pos():
1201+
with ctx.new_func_ctx(FunctionContext.STATICMETHOD), ctx.expr_pos():
11961202
stmts, ret = _body_ast(ctx, runtime.nthrest(form, 2))
11971203
method = DefTypeStaticMethodArity(
11981204
form=form,
@@ -1652,7 +1658,10 @@ def _do_ast(ctx: AnalyzerContext, form: ISeq) -> Do:
16521658

16531659

16541660
def __fn_method_ast( # pylint: disable=too-many-branches,too-many-locals
1655-
ctx: AnalyzerContext, form: ISeq, fnname: Optional[sym.Symbol] = None
1661+
ctx: AnalyzerContext,
1662+
form: ISeq,
1663+
fnname: Optional[sym.Symbol] = None,
1664+
is_async: bool = False,
16561665
) -> FnArity:
16571666
with ctx.new_symbol_table("fn-method"):
16581667
params = form.first
@@ -1711,7 +1720,9 @@ def __fn_method_ast( # pylint: disable=too-many-branches,too-many-locals
17111720

17121721
fn_loop_id = genname("fn_arity" if fnname is None else fnname.name)
17131722
with ctx.new_recur_point(fn_loop_id, param_nodes):
1714-
with ctx.expr_pos():
1723+
with ctx.new_func_ctx(
1724+
FunctionContext.ASYNC_FUNCTION if is_async else FunctionContext.FUNCTION
1725+
), ctx.expr_pos():
17151726
stmts, ret = _body_ast(ctx, form.rest)
17161727
method = FnArity(
17171728
form=form,
@@ -1770,11 +1781,11 @@ def _fn_ast( # pylint: disable=too-many-branches
17701781
form=form,
17711782
)
17721783

1784+
name_node: Optional[Binding]
17731785
if isinstance(name, sym.Symbol):
1774-
name_node: Optional[Binding] = Binding(
1786+
name_node = Binding(
17751787
form=name, name=name.name, local=LocalType.FN, env=ctx.get_node_env()
17761788
)
1777-
assert name_node is not None
17781789
is_async = _is_async(name) or isinstance(form, IMeta) and _is_async(form)
17791790
kwarg_support = (
17801791
__fn_kwargs_support(name)
@@ -1802,23 +1813,24 @@ def _fn_ast( # pylint: disable=too-many-branches
18021813
form=form,
18031814
)
18041815

1805-
with ctx.new_func_ctx(is_async=is_async):
1806-
if isinstance(arity_or_args, llist.List):
1807-
arities = vec.vector(
1808-
map(
1809-
partial(__fn_method_ast, ctx, fnname=name),
1810-
runtime.nthrest(form, idx),
1811-
)
1812-
)
1813-
elif isinstance(arity_or_args, vec.Vector):
1814-
arities = vec.v(
1815-
__fn_method_ast(ctx, runtime.nthrest(form, idx), fnname=name)
1816+
if isinstance(arity_or_args, llist.List):
1817+
arities = vec.vector(
1818+
map(
1819+
partial(__fn_method_ast, ctx, fnname=name, is_async=is_async),
1820+
runtime.nthrest(form, idx),
18161821
)
1817-
else:
1818-
raise AnalyzerException(
1819-
"fn form must match: (fn* name? [arg*] body*) or (fn* name? method*)",
1820-
form=form,
1822+
)
1823+
elif isinstance(arity_or_args, vec.Vector):
1824+
arities = vec.v(
1825+
__fn_method_ast(
1826+
ctx, runtime.nthrest(form, idx), fnname=name, is_async=is_async
18211827
)
1828+
)
1829+
else:
1830+
raise AnalyzerException(
1831+
"fn form must match: (fn* name? [arg*] body*) or (fn* name? method*)",
1832+
form=form,
1833+
)
18221834

18231835
nmethods = count(arities)
18241836
assert nmethods > 0, "fn must have at least one arity"

src/basilisp/lang/compiler/generator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@
131131
_SET_BANG_TEMP_PREFIX = "set_bang_val"
132132
_THROW_PREFIX = "lisp_throw"
133133
_TRY_PREFIX = "lisp_try"
134-
_NS_VAR = "__NS"
134+
_NS_VAR = "_NS"
135135

136136

137137
GeneratorException = partial(CompilerException, phase=CompilerPhase.CODE_GENERATION)
@@ -750,15 +750,19 @@ def _def_to_py_ast( # pylint: disable=too-many-branches
750750
# complaining that we assign the value prior to global declaration.
751751
def_dependencies = list(
752752
chain(
753-
[ast.Global(names=[safe_name])] if node.in_func_ctx else [],
753+
[ast.Global(names=[safe_name])]
754+
if node.env.func_ctx is not None
755+
else [],
754756
def_ast.dependencies,
755757
)
756758
)
757759
else:
758760
def_dependencies = list(
759761
chain(
760762
def_ast.dependencies,
761-
[ast.Global(names=[safe_name])] if node.in_func_ctx else [],
763+
[ast.Global(names=[safe_name])]
764+
if node.env.func_ctx is not None
765+
else [],
762766
[
763767
ast.Assign(
764768
targets=[ast.Name(id=safe_name, ctx=ast.Store())],
@@ -775,7 +779,9 @@ def _def_to_py_ast( # pylint: disable=too-many-branches
775779
# root.
776780
func = _INTERN_UNBOUND_VAR_FN_NAME
777781
extra_args = []
778-
def_dependencies = [ast.Global(names=[safe_name])] if node.in_func_ctx else []
782+
def_dependencies = (
783+
[ast.Global(names=[safe_name])] if node.env.func_ctx is not None else []
784+
)
779785

780786
meta_ast = gen_py_ast(ctx, node.meta)
781787

src/basilisp/lang/compiler/nodes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,15 @@ class ConstType(Enum):
272272
LoopID = str
273273

274274

275+
class FunctionContext(Enum):
276+
FUNCTION = kw.keyword("function")
277+
ASYNC_FUNCTION = kw.keyword("async-function")
278+
METHOD = kw.keyword("method")
279+
CLASSMETHOD = kw.keyword("classmethod")
280+
STATICMETHOD = kw.keyword("staticmethod")
281+
PROPERTY = kw.keyword("property")
282+
283+
275284
class KeywordArgSupport(Enum):
276285
APPLY_KWARGS = kw.keyword("apply")
277286
COLLECT_KWARGS = kw.keyword("collect")
@@ -296,6 +305,7 @@ class NodeEnv:
296305
line: Optional[int] = None
297306
col: Optional[int] = None
298307
pos: Optional[NodeSyntacticPosition] = None
308+
func_ctx: Optional[FunctionContext] = None
299309

300310

301311
@attr.s(auto_attribs=True, frozen=True, slots=True)
@@ -360,7 +370,6 @@ class Def(Node[SpecialForm]):
360370
var: Var
361371
init: Optional[Node]
362372
doc: Optional[str]
363-
in_func_ctx: bool
364373
env: NodeEnv
365374
meta: NodeMeta = None
366375
children: Sequence[kw.Keyword] = vec.Vector.empty()

src/basilisp/lang/runtime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,13 +1656,13 @@ def init_ns_var() -> Var:
16561656
def bootstrap_core(compiler_opts: CompilerOpts) -> None:
16571657
"""Bootstrap the environment with functions that are either difficult to express
16581658
with the very minimal Lisp environment or which are expected by the compiler."""
1659-
__NS = Maybe(Var.find(NS_VAR_SYM)).or_else_raise(
1659+
_NS = Maybe(Var.find(NS_VAR_SYM)).or_else_raise(
16601660
lambda: RuntimeException(f"Dynamic Var {NS_VAR_SYM} not bound!")
16611661
)
16621662

16631663
def in_ns(s: sym.Symbol):
16641664
ns = Namespace.get_or_create(s)
1665-
__NS.value = ns
1665+
_NS.value = ns
16661666
return ns
16671667

16681668
# Vars used in bootstrapping the runtime

tests/basilisp/compiler_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4264,6 +4264,66 @@ def test_nested_bare_sym_will_not_resolve(self, lcompile: CompileFn):
42644264
with pytest.raises(compiler.CompilerException):
42654265
lcompile("basilisp.lang.map.MapEntry.of")
42664266

4267+
@pytest.mark.parametrize(
4268+
"code",
4269+
[
4270+
"""
4271+
(fn []
4272+
(def a :a)
4273+
(var a))""",
4274+
"""
4275+
(import* collections.abc)
4276+
(deftype* Definer []
4277+
:implements [collections.abc/Callable]
4278+
(--call-- [this] (def a :a) (var a)))
4279+
(Definer)""",
4280+
],
4281+
)
4282+
def test_symbol_deffed_in_fn_or_method_will_resolve_in_fn_or_method(
4283+
self, ns: runtime.Namespace, lcompile: CompileFn, code: str,
4284+
):
4285+
# This behavior is peculiar and perhaps even _wrong_, but it matches how
4286+
# Clojure treats Vars defined in functions. Of course, generally speaking,
4287+
# Vars should not be defined like this so I suppose it's not a huge deal.
4288+
fn = lcompile(code)
4289+
4290+
resolved_var = ns.find(sym.symbol("a"))
4291+
assert not resolved_var.is_bound
4292+
4293+
returned_var = fn()
4294+
assert returned_var is resolved_var
4295+
assert returned_var.is_bound
4296+
assert returned_var.value == kw.keyword("a")
4297+
4298+
@pytest.mark.parametrize(
4299+
"code",
4300+
[
4301+
"""
4302+
(do
4303+
(fn [] (def a :a))
4304+
(var a))""",
4305+
"""
4306+
(fn [] (def a :a))
4307+
(var a)""",
4308+
"""
4309+
(import* collections.abc)
4310+
(deftype* Definer []
4311+
:implements [collections.abc/Callable]
4312+
(--call-- [this] (def a :a)))
4313+
(var a)""",
4314+
],
4315+
)
4316+
def test_symbol_deffed_in_fn_or_method_will_resolve_outside_fn_or_method(
4317+
self, ns: runtime.Namespace, lcompile: CompileFn, code: str
4318+
):
4319+
var = lcompile(code)
4320+
assert not var.is_bound
4321+
4322+
resolved_var = ns.find(sym.symbol("a"))
4323+
assert not resolved_var.is_bound
4324+
4325+
assert var is resolved_var
4326+
42674327
def test_local_deftype_classmethod_resolves(self, lcompile: CompileFn):
42684328
Point = lcompile(
42694329
"""

0 commit comments

Comments
 (0)