diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index cdc49edc..4705fb72 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -712,14 +712,30 @@ def _codegen_impl(self, state: CodegenState) -> None: state.increase_indent(state.default_indent if indent is None else indent) if self.body: - with state.record_syntactic_position( - self, start_node=self.body[0], end_node=self.body[-1] + first_statement = self.body[0] + if ( + isinstance(first_statement, (FunctionDef, ClassDef)) + and first_statement.decorators ): - for stmt in self.body: - # IndentedBlock is responsible for adjusting the current indentation level, - # but its children are responsible for actually adding that indentation to - # the token list. - stmt._codegen(state) + # If the first statement is a function or class definition, we need to + # use the position of the first decorator instead of the function/class definition. + with state.record_syntactic_position( + self, start_node=first_statement.decorators[0], end_node=self.body[-1] + ): + for stmt in self.body: + # IndentedBlock is responsible for adjusting the current indentation level, + # but its children are responsible for actually adding that indentation to + # the token list. + stmt._codegen(state) + else: + with state.record_syntactic_position( + self, start_node=first_statement, end_node=self.body[-1] + ): + for stmt in self.body: + # IndentedBlock is responsible for adjusting the current indentation level, + # but its children are responsible for actually adding that indentation to + # the token list. + stmt._codegen(state) else: # Empty indented blocks are not syntactically valid in Python unless # they contain a 'pass' statement, so add one here. diff --git a/libcst/metadata/tests/test_position_provider.py b/libcst/metadata/tests/test_position_provider.py index 14cecec7..79a07c37 100644 --- a/libcst/metadata/tests/test_position_provider.py +++ b/libcst/metadata/tests/test_position_provider.py @@ -83,6 +83,62 @@ def visit_Pass(self, node: cst.Pass) -> None: wrapper = MetadataWrapper(parse_module("pass")) wrapper.visit_batched([ABatchable()]) + def test_indented_block_starting_with_decorated_function_def(self) -> None: + """ + Tests that the position provider correctly computes positions in an indented block + starting with a decorated function definition. + """ + test = self + + class IndentedBlockVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (PositionProvider,) + + def visit_IndentedBlock(self, node: cst.IndentedBlock) -> None: + test.assertEqual( + self.get_metadata(PositionProvider, node), + CodeRange((3, 4), (5, 15)), + ) + + wrapper = MetadataWrapper( + parse_module( + """ # Empty line +def foo(): + @decorator + def func(): return 42 + return func +""" + ) + ) + wrapper.visit(IndentedBlockVisitor()) + + def test_indented_block_starting_with_decorated_class_def(self) -> None: + """ + Tests that the position provider correctly computes positions in an indented block + starting with a decorated class definition. + """ + test = self + + class IndentedBlockVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (PositionProvider,) + + def visit_IndentedBlock(self, node: cst.IndentedBlock) -> None: + test.assertEqual( + self.get_metadata(PositionProvider, node), + CodeRange((3, 4), (5, 18)), + ) + + wrapper = MetadataWrapper( + parse_module( + """ # Empty line +def foo(): + @decorator + class MyClass: pass + return MyClass +""" + ) + ) + wrapper.visit(IndentedBlockVisitor()) + def test_match_statement_position_metadata(self) -> None: test = self