diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 1aba38d3..cdc49edc 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -2886,6 +2886,9 @@ def _codegen_impl(self, state: CodegenState) -> None: state.add_token("if") self.whitespace_after_if._codegen(state) guard._codegen(state) + else: + self.whitespace_before_if._codegen(state) + self.whitespace_after_if._codegen(state) self.whitespace_before_colon._codegen(state) state.add_token(":") @@ -3473,6 +3476,13 @@ def _codegen_impl(self, state: CodegenState) -> None: state.add_token(" ") elif isinstance(ws_after, BaseParenthesizableWhitespace): ws_after._codegen(state) + else: + ws_before = self.whitespace_before_as + if isinstance(ws_before, BaseParenthesizableWhitespace): + ws_before._codegen(state) + ws_after = self.whitespace_after_as + if isinstance(ws_after, BaseParenthesizableWhitespace): + ws_after._codegen(state) if name is None: state.add_token("_") else: diff --git a/libcst/metadata/tests/test_position_provider.py b/libcst/metadata/tests/test_position_provider.py index c479837e..14cecec7 100644 --- a/libcst/metadata/tests/test_position_provider.py +++ b/libcst/metadata/tests/test_position_provider.py @@ -83,6 +83,53 @@ def visit_Pass(self, node: cst.Pass) -> None: wrapper = MetadataWrapper(parse_module("pass")) wrapper.visit_batched([ABatchable()]) + def test_match_statement_position_metadata(self) -> None: + test = self + + class MatchPositionVisitor(CSTVisitor): + METADATA_DEPENDENCIES = (PositionProvider,) + + def visit_Match(self, node: cst.Match) -> None: + test.assertEqual( + self.get_metadata(PositionProvider, node), + CodeRange((2, 0), (5, 16)), + ) + + def visit_MatchCase(self, node: cst.MatchCase) -> None: + if ( + isinstance(node.pattern, cst.MatchAs) + and node.pattern.name + and node.pattern.name.value == "b" + ): + test.assertEqual( + self.get_metadata(PositionProvider, node), + CodeRange((3, 4), (3, 16)), + ) + elif ( + isinstance(node.pattern, cst.MatchAs) + and node.pattern.name + and node.pattern.name.value == "c" + ): + test.assertEqual( + self.get_metadata(PositionProvider, node), + CodeRange((4, 4), (4, 16)), + ) + elif isinstance(node.pattern, cst.MatchAs) and not node.pattern.name: + test.assertEqual( + self.get_metadata(PositionProvider, node), + CodeRange((5, 4), (5, 16)), + ) + + code = """ +match status: + case b: pass + case c: pass + case _: pass +""" + + wrapper = MetadataWrapper(parse_module(code)) + wrapper.visit(MatchPositionVisitor()) + class PositionProvidingCodegenStateTest(UnitTest): def test_codegen_initial_position(self) -> None: