From 48138ab6adf42cc3ace418e3e7232247e600aaa9 Mon Sep 17 00:00:00 2001 From: Chris Rink Date: Wed, 9 Apr 2025 21:07:52 -0400 Subject: [PATCH 1/2] Fix defining async generators --- CHANGELOG.md | 5 +- src/basilisp/lang/compiler/analyzer.py | 133 ++++++++++++++---------- src/basilisp/lang/compiler/generator.py | 12 ++- src/basilisp/lang/compiler/nodes.py | 92 ++++++++-------- 4 files changed, 142 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0009f43e1..de9811ecc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed - * Single arity functions can be tagged with `^:allow-unsafe-names` to preserve their parameter names (#1212) + * Single arity functions can be tagged with `^:allow-unsafe-names` to preserve their parameter names (#1212) + +### Fixed + * Fix an issue where the compiler would generate an illegal `return` statement for asynchronous generators (#1180) ## [v0.3.7] ### Fixed diff --git a/src/basilisp/lang/compiler/analyzer.py b/src/basilisp/lang/compiler/analyzer.py index 956df14c7..b7c4e3e3e 100644 --- a/src/basilisp/lang/compiler/analyzer.py +++ b/src/basilisp/lang/compiler/analyzer.py @@ -10,7 +10,14 @@ import sys import uuid from collections import defaultdict -from collections.abc import Collection, Iterable, Mapping, MutableMapping, MutableSet +from collections.abc import ( + Collection, + Iterable, + Iterator, + Mapping, + MutableMapping, + MutableSet, +) from datetime import datetime from decimal import Decimal from fractions import Fraction @@ -89,6 +96,7 @@ Fn, FnArity, FunctionContext, + FunctionContextType, HostCall, HostField, If, @@ -447,16 +455,23 @@ def is_async_ctx(self) -> bool: It is possible that the current function is defined inside other functions, so this does not imply anything about the nesting level of the current node.""" - return self.func_ctx == FunctionContext.ASYNC_FUNCTION + func_ctx = self.func_ctx + return ( + func_ctx is not None + and func_ctx.function_type == FunctionContextType.ASYNC_FUNCTION + ) @contextlib.contextmanager - def new_func_ctx(self, context_type: FunctionContext): + def new_func_ctx( + self, context_type: FunctionContextType + ) -> Iterator[FunctionContext]: """Context manager which can be used to set a function or method context for child nodes to examine. A new function context is pushed onto the stack each time the Analyzer finds a new function or method definition, so there may be many nested function contexts.""" - self._func_ctx.append(context_type) - yield + func_ctx = FunctionContext(context_type) + self._func_ctx.append(func_ctx) + yield func_ctx self._func_ctx.pop() @property @@ -1186,7 +1201,7 @@ def __deftype_method_param_bindings( return has_vargs, fixed_arity, param_nodes -def __deftype_classmethod( +def __deftype_classmethod( # pylint: disable=too-many-locals form: Union[llist.PersistentList, ISeq], ctx: AnalyzerContext, method_name: str, @@ -1222,16 +1237,9 @@ def __deftype_classmethod( has_vargs, fixed_arity, param_nodes = __deftype_method_param_bindings( params, ctx, SpecialForm.DEFTYPE ) - with ctx.new_func_ctx(FunctionContext.CLASSMETHOD), ctx.expr_pos(): + with ctx.new_func_ctx(FunctionContextType.CLASSMETHOD), ctx.expr_pos(): stmts, ret = _body_ast(runtime.nthrest(form, 2), ctx) - method = DefTypeClassMethod( - form=form, - name=method_name, - params=vec.vector(param_nodes), - fixed_arity=fixed_arity, - is_variadic=has_vargs, - kwarg_support=kwarg_support, - body=Do( + body = Do( form=form.rest, statements=vec.vector(stmts), ret=ret, @@ -1239,7 +1247,15 @@ def __deftype_classmethod( # Use the argument vector or first body statement, whichever # exists, for metadata. env=ctx.get_node_env(), - ), + ) + method = DefTypeClassMethod( + form=form, + name=method_name, + params=vec.vector(param_nodes), + fixed_arity=fixed_arity, + is_variadic=has_vargs, + kwarg_support=kwarg_support, + body=body, class_local=cls_binding, env=ctx.get_node_env(), ) @@ -1286,8 +1302,17 @@ def __deftype_or_reify_method( # pylint: disable=too-many-arguments,too-many-lo loop_id = genname(method_name) with ctx.new_recur_point(loop_id, param_nodes): - with ctx.new_func_ctx(FunctionContext.METHOD), ctx.expr_pos(): + with ctx.new_func_ctx(FunctionContextType.METHOD), ctx.expr_pos(): stmts, ret = _body_ast(runtime.nthrest(form, 2), ctx) + body = Do( + form=form.rest, + statements=vec.vector(stmts), + ret=ret, + is_body=True, + # Use the argument vector or first body statement, whichever + # exists, for metadata. + env=ctx.get_node_env(), + ) method = DefTypeMethodArity( form=form, name=method_name, @@ -1296,15 +1321,7 @@ def __deftype_or_reify_method( # pylint: disable=too-many-arguments,too-many-lo fixed_arity=fixed_arity, is_variadic=has_vargs, kwarg_support=kwarg_support, - body=Do( - form=form.rest, - statements=vec.vector(stmts), - ret=ret, - is_body=True, - # Use the argument vector or first body statement, whichever - # exists, for metadata. - env=ctx.get_node_env(), - ), + body=body, loop_id=loop_id, env=ctx.get_node_env(), ) @@ -1356,14 +1373,9 @@ def __deftype_or_reify_property( assert not has_vargs, f"{special_form} properties may not have arguments" - with ctx.new_func_ctx(FunctionContext.PROPERTY), ctx.expr_pos(): + with ctx.new_func_ctx(FunctionContextType.PROPERTY), ctx.expr_pos(): stmts, ret = _body_ast(runtime.nthrest(form, 2), ctx) - prop = DefTypeProperty( - form=form, - name=method_name, - this_local=this_binding, - params=vec.vector(param_nodes), - body=Do( + body = Do( form=form.rest, statements=vec.vector(stmts), ret=ret, @@ -1371,7 +1383,13 @@ def __deftype_or_reify_property( # Use the argument vector or first body statement, whichever # exists, for metadata. env=ctx.get_node_env(), - ), + ) + prop = DefTypeProperty( + form=form, + name=method_name, + this_local=this_binding, + params=vec.vector(param_nodes), + body=body, env=ctx.get_node_env(), ) prop.visit(partial(_assert_no_recur, ctx)) @@ -1393,16 +1411,9 @@ def __deftype_staticmethod( has_vargs, fixed_arity, param_nodes = __deftype_method_param_bindings( args, ctx, SpecialForm.DEFTYPE ) - with ctx.new_func_ctx(FunctionContext.STATICMETHOD), ctx.expr_pos(): + with ctx.new_func_ctx(FunctionContextType.STATICMETHOD), ctx.expr_pos(): stmts, ret = _body_ast(runtime.nthrest(form, 2), ctx) - method = DefTypeStaticMethod( - form=form, - name=method_name, - params=vec.vector(param_nodes), - fixed_arity=fixed_arity, - is_variadic=has_vargs, - kwarg_support=kwarg_support, - body=Do( + body = Do( form=form.rest, statements=vec.vector(stmts), ret=ret, @@ -1410,7 +1421,15 @@ def __deftype_staticmethod( # Use the argument vector or first body statement, whichever # exists, for metadata. env=ctx.get_node_env(), - ), + ) + method = DefTypeStaticMethod( + form=form, + name=method_name, + params=vec.vector(param_nodes), + fixed_arity=fixed_arity, + is_variadic=has_vargs, + kwarg_support=kwarg_support, + body=body, env=ctx.get_node_env(), ) method.visit(partial(_assert_no_recur, ctx)) @@ -2092,21 +2111,14 @@ def __fn_method_ast( # pylint: disable=too-many-locals with ctx.new_recur_point(fn_loop_id, param_nodes): with ( ctx.new_func_ctx( - FunctionContext.ASYNC_FUNCTION + FunctionContextType.ASYNC_FUNCTION if is_async - else FunctionContext.FUNCTION + else FunctionContextType.FUNCTION ), ctx.expr_pos(), ): stmts, ret = _body_ast(form.rest, ctx) - method = FnArity( - form=form, - loop_id=fn_loop_id, - params=vec.vector(param_nodes), - tag=return_tag, - is_variadic=has_vargs, - fixed_arity=len(param_nodes) - int(has_vargs), - body=Do( + body = Do( form=form.rest, statements=vec.vector(stmts), ret=ret, @@ -2114,7 +2126,15 @@ def __fn_method_ast( # pylint: disable=too-many-locals # Use the argument vector or first body statement, whichever # exists, for metadata. env=ctx.get_node_env(), - ), + ) + method = FnArity( + form=form, + loop_id=fn_loop_id, + params=vec.vector(param_nodes), + tag=return_tag, + is_variadic=has_vargs, + fixed_arity=len(param_nodes) - int(has_vargs), + body=body, # Use the argument vector for fetching line/col since the # form itself is a sequence with no meaningful metadata. env=ctx.get_node_env(), @@ -3420,6 +3440,9 @@ def _yield_ast(form: ISeq, ctx: AnalyzerContext) -> Yield: "yield forms must contain 1 or 2 elements, as in: (yield [expr])", form=form ) + # Indicate that the current function is a generator + ctx.func_ctx.is_generator = True + if nelems == 2: with ctx.expr_pos(): expr = _analyze_form(runtime.nth(form, 1), ctx) diff --git a/src/basilisp/lang/compiler/generator.py b/src/basilisp/lang/compiler/generator.py index 22a289469..f14a01c39 100644 --- a/src/basilisp/lang/compiler/generator.py +++ b/src/basilisp/lang/compiler/generator.py @@ -60,6 +60,7 @@ Do, Fn, FnArity, + FunctionContextType, HostCall, HostField, If, @@ -1715,7 +1716,16 @@ def __fn_args_to_py_ast( body_ast = _synthetic_do_to_py_ast(ctx, body) fn_body_ast.extend(map(statementize, body_ast.dependencies)) - fn_body_ast.append(ast.Return(value=body_ast.node)) + + func_ctx = body.env.func_ctx + if ( + func_ctx is not None + and func_ctx.is_generator + and func_ctx.function_type == FunctionContextType.ASYNC_FUNCTION + ): + fn_body_ast.append(ast.Pass()) + else: + fn_body_ast.append(ast.Return(value=body_ast.node)) return fn_args, varg, fn_body_ast, fn_def_deps diff --git a/src/basilisp/lang/compiler/nodes.py b/src/basilisp/lang/compiler/nodes.py index e60399d31..aee5a44e6 100644 --- a/src/basilisp/lang/compiler/nodes.py +++ b/src/basilisp/lang/compiler/nodes.py @@ -293,7 +293,7 @@ class ConstType(Enum): LoopID = str -class FunctionContext(Enum): +class FunctionContextType(Enum): FUNCTION = kw.keyword("function") ASYNC_FUNCTION = kw.keyword("async-function") METHOD = kw.keyword("method") @@ -302,6 +302,12 @@ class FunctionContext(Enum): PROPERTY = kw.keyword("property") +@attr.define +class FunctionContext: + function_type: FunctionContextType + is_generator: bool = False + + class KeywordArgSupport(Enum): APPLY_KWARGS = kw.keyword("apply") COLLECT_KWARGS = kw.keyword("collect") @@ -336,7 +342,7 @@ class NodeEnv: class Await(Node[ReaderLispForm]): form: ReaderLispForm expr: Node - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(EXPR) op: NodeOp = NodeOp.AWAIT top_level: bool = False @@ -348,7 +354,7 @@ class Binding(Node[sym.Symbol], Assignable): form: sym.Symbol name: str local: LocalType - env: NodeEnv + env: NodeEnv = attr.field(hash=False) tag: Optional[Node] = None arg_id: Optional[int] = None is_variadic: bool = False @@ -367,7 +373,7 @@ class Catch(Node[SpecialForm]): class_: Union["MaybeClass", "MaybeHostForm"] local: Binding body: "Do" - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(CLASS, LOCAL, BODY) op: NodeOp = NodeOp.CATCH top_level: bool = False @@ -380,7 +386,7 @@ class Const(Node[ReaderLispForm]): type: ConstType val: ReaderLispForm is_literal: bool - env: NodeEnv + env: NodeEnv = attr.field(hash=False) meta: NodeMeta = None children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.CONST @@ -395,7 +401,7 @@ class Def(Node[SpecialForm]): var: Var init: Optional[Node] doc: Optional[str] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) tag: Optional[Node] = None meta: NodeMeta = None children: Sequence[kw.Keyword] = vec.EMPTY @@ -414,7 +420,7 @@ class DefType(Node[SpecialForm]): interfaces: Iterable[DefTypeBase] fields: Iterable[Binding] members: Iterable["DefTypeMember"] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) verified_abstract: bool = False artificially_abstract: IPersistentSet[DefTypeBase] = lset.EMPTY is_frozen: bool = True @@ -434,7 +440,7 @@ def python_member_names(self) -> Iterable[str]: class DefTypeMember(Node[SpecialForm]): form: SpecialForm name: str - env: NodeEnv + env: NodeEnv = attr.field(hash=False) @property def python_name(self) -> str: @@ -475,7 +481,7 @@ class DefTypeMethodArity(Node[SpecialForm]): body: "Do" this_local: Binding loop_id: LoopID - env: NodeEnv + env: NodeEnv = attr.field(hash=False) is_variadic: bool = False kwarg_support: Optional[KeywordArgSupport] = None children: Sequence[kw.Keyword] = vec.v(THIS_LOCAL, PARAMS, BODY) @@ -520,7 +526,7 @@ class Do(Node[SpecialForm]): form: SpecialForm statements: Iterable[Node] ret: Node - env: NodeEnv + env: NodeEnv = attr.field(hash=False) is_body: bool = False use_var_indirection: bool = False children: Sequence[kw.Keyword] = vec.v(STATEMENTS, RET) @@ -534,7 +540,7 @@ class Fn(Node[SpecialForm]): form: SpecialForm max_fixed_arity: int arities: IPersistentVector["FnArity"] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) local: Optional[Binding] = None is_variadic: bool = False is_async: bool = False @@ -553,7 +559,7 @@ class FnArity(Node[SpecialForm]): params: Iterable[Binding] fixed_arity: int body: Do - env: NodeEnv + env: NodeEnv = attr.field(hash=False) tag: Optional[Node] = None is_variadic: bool = False children: Sequence[kw.Keyword] = vec.v(PARAMS, BODY) @@ -569,7 +575,7 @@ class HostCall(Node[SpecialForm]): target: Node args: Iterable[Node] kwargs: KeywordArgs - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(TARGET, ARGS) op: NodeOp = NodeOp.HOST_CALL top_level: bool = False @@ -581,7 +587,7 @@ class HostField(Node[Union[SpecialForm, sym.Symbol]], Assignable): form: Union[SpecialForm, sym.Symbol] field: str target: Node - env: NodeEnv + env: NodeEnv = attr.field(hash=False) is_assignable: bool = True children: Sequence[kw.Keyword] = vec.v(TARGET) op: NodeOp = NodeOp.HOST_FIELD @@ -594,7 +600,7 @@ class If(Node[SpecialForm]): form: SpecialForm test: Node then: Node - env: NodeEnv + env: NodeEnv = attr.field(hash=False) else_: Node children: Sequence[kw.Keyword] = vec.v(TEST, THEN, ELSE) op: NodeOp = NodeOp.IF @@ -606,7 +612,7 @@ class If(Node[SpecialForm]): class Import(Node[SpecialForm]): form: SpecialForm aliases: Iterable["ImportAlias"] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.IMPORT top_level: bool = False @@ -618,7 +624,7 @@ class ImportAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): form: Union[sym.Symbol, vec.PersistentVector] name: str alias: Optional[str] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.IMPORT_ALIAS top_level: bool = False @@ -631,7 +637,7 @@ class Invoke(Node[SpecialForm]): fn: Node args: Iterable[Node] kwargs: KeywordArgs - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(FN, ARGS) op: NodeOp = NodeOp.INVOKE top_level: bool = False @@ -643,7 +649,7 @@ class Let(Node[SpecialForm]): form: SpecialForm bindings: Iterable[Binding] body: Do - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(BINDINGS, BODY) op: NodeOp = NodeOp.LET top_level: bool = False @@ -655,7 +661,7 @@ class LetFn(Node[SpecialForm]): form: SpecialForm bindings: Iterable[Binding] body: Do - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(BINDINGS, BODY) op: NodeOp = NodeOp.LETFN top_level: bool = False @@ -667,7 +673,7 @@ class Local(Node[sym.Symbol], Assignable): form: sym.Symbol name: str local: LocalType - env: NodeEnv + env: NodeEnv = attr.field(hash=False) is_assignable: bool = False arg_id: Optional[int] = None is_variadic: bool = False @@ -683,7 +689,7 @@ class Loop(Node[SpecialForm]): bindings: Iterable[Binding] body: Do loop_id: LoopID - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(BINDINGS, BODY) op: NodeOp = NodeOp.LOOP top_level: bool = False @@ -695,7 +701,7 @@ class Map(Node[IPersistentMap]): form: IPersistentMap keys: Iterable[Node] vals: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(KEYS, VALS) op: NodeOp = NodeOp.MAP top_level: bool = False @@ -707,7 +713,7 @@ class MaybeClass(Node[sym.Symbol]): form: sym.Symbol class_: str target: Any - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.MAYBE_CLASS top_level: bool = False @@ -720,7 +726,7 @@ class MaybeHostForm(Node[sym.Symbol]): class_: str field: str target: Any - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.MAYBE_HOST_FORM top_level: bool = False @@ -732,7 +738,7 @@ class PyDict(Node[dict]): form: dict keys: Iterable[Node] vals: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(KEYS, VALS) op: NodeOp = NodeOp.PY_DICT top_level: bool = False @@ -743,7 +749,7 @@ class PyDict(Node[dict]): class PyList(Node[list]): form: list items: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(ITEMS) op: NodeOp = NodeOp.PY_LIST top_level: bool = False @@ -754,7 +760,7 @@ class PyList(Node[list]): class PySet(Node[Union[frozenset, set]]): form: Union[frozenset, set] items: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(ITEMS) op: NodeOp = NodeOp.PY_SET top_level: bool = False @@ -765,7 +771,7 @@ class PySet(Node[Union[frozenset, set]]): class PyTuple(Node[tuple]): form: tuple items: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(ITEMS) op: NodeOp = NodeOp.PY_TUPLE top_level: bool = False @@ -776,7 +782,7 @@ class PyTuple(Node[tuple]): class Queue(Node[lqueue.PersistentQueue]): form: lqueue.PersistentQueue items: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(ITEMS) op: NodeOp = NodeOp.QUEUE top_level: bool = False @@ -787,7 +793,7 @@ class Queue(Node[lqueue.PersistentQueue]): class Quote(Node[SpecialForm]): form: SpecialForm expr: Const - env: NodeEnv + env: NodeEnv = attr.field(hash=False) is_literal: bool = True children: Sequence[kw.Keyword] = vec.v(EXPR) op: NodeOp = NodeOp.QUOTE @@ -800,7 +806,7 @@ class Recur(Node[SpecialForm]): form: SpecialForm exprs: Iterable[Node] loop_id: LoopID - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(EXPRS) op: NodeOp = NodeOp.RECUR top_level: bool = False @@ -812,7 +818,7 @@ class Reify(Node[SpecialForm]): form: SpecialForm interfaces: Iterable[DefTypeBase] members: Iterable["DefTypeMember"] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) verified_abstract: bool = False artificially_abstract: IPersistentSet[DefTypeBase] = lset.EMPTY is_frozen: bool = True @@ -833,7 +839,7 @@ class RequireAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): form: Union[sym.Symbol, vec.PersistentVector] name: str alias: Optional[str] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.REQUIRE_ALIAS top_level: bool = False @@ -844,7 +850,7 @@ class RequireAlias(Node[Union[sym.Symbol, vec.PersistentVector]]): class Require(Node[SpecialForm]): form: SpecialForm aliases: Iterable[RequireAlias] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.EMPTY op: NodeOp = NodeOp.REQUIRE top_level: bool = False @@ -855,7 +861,7 @@ class Require(Node[SpecialForm]): class Set(Node[IPersistentSet]): form: IPersistentSet items: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(ITEMS) op: NodeOp = NodeOp.SET top_level: bool = False @@ -867,7 +873,7 @@ class SetBang(Node[SpecialForm]): form: SpecialForm target: Union[Assignable, Node] val: Node - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(TARGET, VAL) op: NodeOp = NodeOp.SET_BANG top_level: bool = False @@ -879,7 +885,7 @@ class Throw(Node[SpecialForm]): form: SpecialForm exception: Node cause: Optional[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(EXCEPTION) op: NodeOp = NodeOp.THROW top_level: bool = False @@ -892,7 +898,7 @@ class Try(Node[SpecialForm]): body: Do catches: Iterable[Catch] children: Sequence[kw.Keyword] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) finally_: Optional[Do] = None op: NodeOp = NodeOp.TRY top_level: bool = False @@ -903,7 +909,7 @@ class Try(Node[SpecialForm]): class VarRef(Node[Union[sym.Symbol, ISeq]], Assignable): form: Union[sym.Symbol, ISeq] var: Var - env: NodeEnv + env: NodeEnv = attr.field(hash=False) return_var: bool = False is_assignable: bool = True is_allow_var_indirection: bool = False @@ -917,7 +923,7 @@ class VarRef(Node[Union[sym.Symbol, ISeq]], Assignable): class Vector(Node[IPersistentVector]): form: IPersistentVector items: Iterable[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(ITEMS) op: NodeOp = NodeOp.VECTOR top_level: bool = False @@ -932,7 +938,7 @@ class WithMeta(Node[LispForm], Generic[T_withmeta]): form: LispForm meta: Union[Const, Map] expr: T_withmeta - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(META, EXPR) op: NodeOp = NodeOp.WITH_META top_level: bool = False @@ -943,7 +949,7 @@ class WithMeta(Node[LispForm], Generic[T_withmeta]): class Yield(Node[SpecialForm]): form: SpecialForm expr: Optional[Node] - env: NodeEnv + env: NodeEnv = attr.field(hash=False) children: Sequence[kw.Keyword] = vec.v(EXPR) op: NodeOp = NodeOp.YIELD top_level: bool = False From 98c25a53c8cdf2b0559ddf68f5c1711775fa3ea9 Mon Sep 17 00:00:00 2001 From: Chris Rink Date: Wed, 9 Apr 2025 21:29:22 -0400 Subject: [PATCH 2/2] Test --- src/basilisp/lang/compiler/analyzer.py | 12 --------- src/basilisp/lang/compiler/generator.py | 2 +- tests/basilisp/compiler_test.py | 36 +++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/basilisp/lang/compiler/analyzer.py b/src/basilisp/lang/compiler/analyzer.py index b7c4e3e3e..436a56572 100644 --- a/src/basilisp/lang/compiler/analyzer.py +++ b/src/basilisp/lang/compiler/analyzer.py @@ -1244,8 +1244,6 @@ def __deftype_classmethod( # pylint: disable=too-many-locals statements=vec.vector(stmts), ret=ret, is_body=True, - # Use the argument vector or first body statement, whichever - # exists, for metadata. env=ctx.get_node_env(), ) method = DefTypeClassMethod( @@ -1309,8 +1307,6 @@ def __deftype_or_reify_method( # pylint: disable=too-many-arguments,too-many-lo statements=vec.vector(stmts), ret=ret, is_body=True, - # Use the argument vector or first body statement, whichever - # exists, for metadata. env=ctx.get_node_env(), ) method = DefTypeMethodArity( @@ -1380,8 +1376,6 @@ def __deftype_or_reify_property( statements=vec.vector(stmts), ret=ret, is_body=True, - # Use the argument vector or first body statement, whichever - # exists, for metadata. env=ctx.get_node_env(), ) prop = DefTypeProperty( @@ -1418,8 +1412,6 @@ def __deftype_staticmethod( statements=vec.vector(stmts), ret=ret, is_body=True, - # Use the argument vector or first body statement, whichever - # exists, for metadata. env=ctx.get_node_env(), ) method = DefTypeStaticMethod( @@ -2123,8 +2115,6 @@ def __fn_method_ast( # pylint: disable=too-many-locals statements=vec.vector(stmts), ret=ret, is_body=True, - # Use the argument vector or first body statement, whichever - # exists, for metadata. env=ctx.get_node_env(), ) method = FnArity( @@ -2135,8 +2125,6 @@ def __fn_method_ast( # pylint: disable=too-many-locals is_variadic=has_vargs, fixed_arity=len(param_nodes) - int(has_vargs), body=body, - # Use the argument vector for fetching line/col since the - # form itself is a sequence with no meaningful metadata. env=ctx.get_node_env(), ) method.visit(partial(_assert_recur_is_tail, ctx)) diff --git a/src/basilisp/lang/compiler/generator.py b/src/basilisp/lang/compiler/generator.py index f14a01c39..f5f68b379 100644 --- a/src/basilisp/lang/compiler/generator.py +++ b/src/basilisp/lang/compiler/generator.py @@ -1723,7 +1723,7 @@ def __fn_args_to_py_ast( and func_ctx.is_generator and func_ctx.function_type == FunctionContextType.ASYNC_FUNCTION ): - fn_body_ast.append(ast.Pass()) + fn_body_ast.append(statementize(body_ast.node)) else: fn_body_ast.append(ast.Return(value=body_ast.node)) diff --git a/tests/basilisp/compiler_test.py b/tests/basilisp/compiler_test.py index 6e72db39b..eb8f2b8f3 100644 --- a/tests/basilisp/compiler_test.py +++ b/tests/basilisp/compiler_test.py @@ -3025,6 +3025,42 @@ def test_async_multi_arity(self, lcompile: CompileFn): kw.keyword("await-result-0"), kw.keyword("await-result-1") ) == async_to_sync(awaiter) + def test_async_generator_single_statement(self, lcompile: CompileFn): + awaiter_fn: runtime.Var = lcompile( + """ + (fn ^:async unique-kdghii + [] + (yield :async-generator)) + """ + ) + + async def get(): + async for val in awaiter_fn(): + return val + + assert kw.keyword("async-generator") == async_to_sync(get) + + def test_async_generator_multi_statement(self, lcompile: CompileFn): + awaiter_fn: runtime.Var = lcompile( + """ + (fn ^:async unique-kdghii + [] + (yield :async-generator-1) + (yield :async-generator-2)) + """ + ) + + async def get(): + vals = [] + async for val in awaiter_fn(): + vals.append(val) + return vals + + assert [ + kw.keyword("async-generator-1"), + kw.keyword("async-generator-2"), + ] == async_to_sync(get) + def test_fn_with_meta_must_be_map(self, lcompile: CompileFn): f = lcompile("^:meta-kw (fn* [] :super-unique-kw)") with pytest.raises(TypeError):