Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,14 +712,30 @@ def _codegen_impl(self, state: CodegenState) -> None:
state.increase_indent(state.default_indent if indent is None else indent)

if self.body:
with state.record_syntactic_position(
self, start_node=self.body[0], end_node=self.body[-1]
first_statement = self.body[0]
if (
isinstance(first_statement, (FunctionDef, ClassDef))
and first_statement.decorators
):
for stmt in self.body:
# IndentedBlock is responsible for adjusting the current indentation level,
# but its children are responsible for actually adding that indentation to
# the token list.
stmt._codegen(state)
# If the first statement is a function or class definition, we need to
# use the position of the first decorator instead of the function/class definition.
with state.record_syntactic_position(
self, start_node=first_statement.decorators[0], end_node=self.body[-1]
):
for stmt in self.body:
# IndentedBlock is responsible for adjusting the current indentation level,
# but its children are responsible for actually adding that indentation to
# the token list.
stmt._codegen(state)
else:
with state.record_syntactic_position(
self, start_node=first_statement, end_node=self.body[-1]
):
for stmt in self.body:
# IndentedBlock is responsible for adjusting the current indentation level,
# but its children are responsible for actually adding that indentation to
# the token list.
stmt._codegen(state)
else:
# Empty indented blocks are not syntactically valid in Python unless
# they contain a 'pass' statement, so add one here.
Expand Down
56 changes: 56 additions & 0 deletions libcst/metadata/tests/test_position_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,62 @@ def visit_Pass(self, node: cst.Pass) -> None:
wrapper = MetadataWrapper(parse_module("pass"))
wrapper.visit_batched([ABatchable()])

def test_indented_block_starting_with_decorated_function_def(self) -> None:
"""
Tests that the position provider correctly computes positions in an indented block
starting with a decorated function definition.
"""
test = self

class IndentedBlockVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_IndentedBlock(self, node: cst.IndentedBlock) -> None:
test.assertEqual(
self.get_metadata(PositionProvider, node),
CodeRange((3, 4), (5, 15)),
)

wrapper = MetadataWrapper(
parse_module(
""" # Empty line
def foo():
@decorator
def func(): return 42
return func
"""
)
)
wrapper.visit(IndentedBlockVisitor())

def test_indented_block_starting_with_decorated_class_def(self) -> None:
"""
Tests that the position provider correctly computes positions in an indented block
starting with a decorated class definition.
"""
test = self

class IndentedBlockVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_IndentedBlock(self, node: cst.IndentedBlock) -> None:
test.assertEqual(
self.get_metadata(PositionProvider, node),
CodeRange((3, 4), (5, 18)),
)

wrapper = MetadataWrapper(
parse_module(
""" # Empty line
def foo():
@decorator
class MyClass: pass
return MyClass
"""
)
)
wrapper.visit(IndentedBlockVisitor())

def test_match_statement_position_metadata(self) -> None:
test = self

Expand Down