Skip to content
This repository was archived by the owner on Jan 13, 2026. It is now read-only.

Commit 0d32702

Browse files
committed
Support appending statements to blocks using templating
1 parent 56e73ba commit 0d32702

File tree

7 files changed

+161
-15
lines changed

7 files changed

+161
-15
lines changed

rewrite/rewrite/java/support_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,24 @@ def replace(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
335335

336336
@dataclass
337337
class _ExpressionCoordinateBuilder(CoordinateBuilder):
338+
def after(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
339+
return JavaCoordinates(self.tree, loc or Space.Location.EXPRESSION_PREFIX, JavaCoordinates.Mode.AFTER)
340+
341+
def before(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
342+
return JavaCoordinates(self.tree, loc or Space.Location.EXPRESSION_PREFIX, JavaCoordinates.Mode.BEFORE)
343+
338344
def replace(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
339345
return JavaCoordinates(self.tree, loc or Space.Location.EXPRESSION_PREFIX, JavaCoordinates.Mode.REPLACE)
340346

341347

342348
@dataclass
343349
class _StatementCoordinateBuilder(CoordinateBuilder):
350+
def after(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
351+
return JavaCoordinates(self.tree, loc or Space.Location.STATEMENT_PREFIX, JavaCoordinates.Mode.AFTER)
352+
353+
def before(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
354+
return JavaCoordinates(self.tree, loc or Space.Location.STATEMENT_PREFIX, JavaCoordinates.Mode.BEFORE)
355+
344356
def replace(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
345357
return JavaCoordinates(self.tree, loc or Space.Location.STATEMENT_PREFIX, JavaCoordinates.Mode.REPLACE)
346358

rewrite/rewrite/python/format/auto_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def get_visitor(self):
2121
return AutoFormatVisitor()
2222

2323

24-
class AutoFormatVisitor(PythonVisitor):
24+
class AutoFormatVisitor(PythonVisitor[P]):
2525
def __init__(self, stop_after: Optional[Tree] = None):
2626
self._stop_after = stop_after
2727

rewrite/rewrite/python/templating.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import re
44
from dataclasses import dataclass, field
5-
from typing import Any, Callable, Optional, cast, List
5+
from typing import Any, Callable, Optional, cast, List, Type
66

77
from . import CompilationUnit, ExpressionStatement, PythonVisitor
88
from .parser import PythonParserBuilder
99
from .printer import PythonPrinter
10-
from .. import PrintOutputCapture
11-
from ..java import J, JavaCoordinates, Space, Identifier, P, Literal
10+
from .. import PrintOutputCapture, Tree
11+
from ..java import J, JavaCoordinates, Space, Identifier, P, Literal, Expression, Statement, Block
1212
from ..parser import ParserBuilder
13+
from ..utils import list_flat_map
1314
from ..visitor import Cursor
1415

1516

@@ -30,8 +31,9 @@ def apply(self, scope: Cursor, coordinates: JavaCoordinates, *parameters) -> J:
3031
substituted = substitutions.substitute()
3132
if self.on_after_variable_substitution:
3233
self.on_after_variable_substitution(substituted)
33-
parsed = self._template_parser.parse_expression(scope, substituted, coordinates.loc)
34-
return substitutions.unsubstitute(parsed).with_prefix(cast(J, scope.value).prefix)
34+
35+
return PythonTemplatePythonExtension(self._template_parser, substitutions, substituted, coordinates)\
36+
.visit(cast(Tree, scope.value), 0, scope.parent_or_throw)
3537

3638
def substitutions(self, parameters: List[Any]) -> Substitutions:
3739
return Substitutions(self.code, parameters)
@@ -50,6 +52,10 @@ def parse_expression(self, scope: Cursor, template: str, loc: Space.Location) ->
5052
cu.statements[0]
5153
return j.with_prefix(cast(J, scope.value).prefix)
5254

55+
def parse_block_statements(self, cursor: Cursor, expected: Type, template: str, loc: Space.Location, mode: JavaCoordinates.Mode) -> List[J]:
56+
cu: CompilationUnit = next(iter(self.parser_builder.build().parse_strings(template)))
57+
return cu.statements
58+
5359

5460
@dataclass
5561
class Substitutions:
@@ -110,6 +116,9 @@ def substitute_untyped(self, index: int) -> str:
110116
def unsubstitute(self, parsed: J) -> J:
111117
return cast(J, UnsubstitutionVisitor(self.parameters).visit(parsed, 0))
112118

119+
def unsubstitute_all(self, parsed: List[J]) -> List[J]:
120+
return [self.unsubstitute(j) for j in parsed]
121+
113122

114123
@dataclass
115124
class UnsubstitutionVisitor(PythonVisitor[int]):
@@ -120,3 +129,46 @@ def visit_identifier(self, identifier: Identifier, p: P) -> J:
120129
if match := self._param_pattern.fullmatch(identifier.simple_name):
121130
return cast(J, self.parameters[int(match.group(1))]).with_prefix(identifier.prefix)
122131
return identifier
132+
133+
@dataclass
134+
class PythonTemplatePythonExtension(PythonVisitor[int]):
135+
template_parser: PythonTemplateParser
136+
substitutions: Substitutions
137+
substituted_template: str
138+
coordinates: JavaCoordinates
139+
140+
def __post_init__(self):
141+
self.insertion_point = self.coordinates.tree
142+
self.loc = self.coordinates.loc
143+
self.mode = self.coordinates.mode
144+
145+
def visit_expression(self, expression: Expression, p: int) -> J:
146+
if (self.loc == Space.Location.EXPRESSION_PREFIX or self.loc == Space.Location.STATEMENT_PREFIX and
147+
isinstance(expression, Statement)) and expression.is_scope(self.insertion_point):
148+
parsed = self.template_parser.parse_expression(self.cursor, self.substituted_template, self.loc)
149+
return self.auto_format(self.substitutions.unsubstitute(parsed).with_prefix(expression.prefix), p)
150+
return expression
151+
152+
def visit_block(self, block: Block, p: P) -> J:
153+
if self.loc == Space.Location.STATEMENT_PREFIX:
154+
return self.auto_format(block.with_statements(list_flat_map(lambda s: self.get_replacements(s) if s.is_scope(self.insertion_point) else s, block.statements)), p, self.cursor.parent)
155+
return super().visit_block(block, p)
156+
157+
def visit_statement(self, statement: Statement, p: P) -> J:
158+
return statement
159+
# if (self.loc == Space.Location.STATEMENT_PREFIX and statement.is_scope(self.insertion_point):
160+
# parsed = self.template_parser.parse_expression(self.cursor, self.substituted_template, self.loc)
161+
# return self.auto_format(self.substitutions.unsubstitute(parsed).with_prefix(expression.prefix), p)
162+
# return expression
163+
164+
def get_replacements(self, statement: Statement) -> List[J]:
165+
parsed = self.template_parser.parse_block_statements(Cursor(self.cursor, self.insertion_point), Statement, self.substituted_template, self.loc, self.mode)
166+
gen = self.substitutions.unsubstitute_all(parsed)
167+
formatted = [s.with_prefix(statement.prefix.with_comments([])) for s in gen]
168+
169+
if self.mode == JavaCoordinates.Mode.REPLACE:
170+
return formatted
171+
elif self.mode == JavaCoordinates.Mode.BEFORE:
172+
return formatted + [statement]
173+
else:
174+
return [statement] + formatted

rewrite/rewrite/python/visitor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ class PythonVisitor(JavaVisitor[P]):
1111
def is_acceptable(self, source_file: SourceFile, p: P) -> bool:
1212
return isinstance(source_file, Py)
1313

14+
def auto_format(self, j: J, p: P, cursor: Optional[Cursor] = None, stop_after: Optional[J] = None) -> J:
15+
cursor = cursor or self.cursor.parent_tree_cursor()
16+
from .format import AutoFormatVisitor
17+
return AutoFormatVisitor(stop_after).visit(j, p, cursor)
18+
1419
def visit_async(self, async_: Async, p: P) -> J:
1520
async_ = async_.with_prefix(self.visit_space(async_.prefix, PySpace.Location.ASYNC_PREFIX, p))
1621
temp_statement = cast(Statement, self.visit_statement(async_, p))

rewrite/rewrite/tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def print(self, cursor: 'Cursor', capture: 'PrintOutputCapture[P]') -> str:
5454
def printer(self, cursor: 'Cursor') -> 'TreeVisitor[Any, PrintOutputCapture[P]]':
5555
return cursor.first_enclosing_or_throw(SourceFile).printer(cursor)
5656

57+
def is_scope(self, tree: Optional[Tree]) -> bool:
58+
return tree and tree.id == self.id
59+
5760
def __eq__(self, other: object) -> bool:
5861
if self.__class__ == other.__class__:
5962
return self.id == cast(Tree, other).id

rewrite/rewrite/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def random_id() -> UUID:
1010

1111
# Define a type that allows both single and two-argument callables
1212
FnType = Union[Callable[[T], Union[T, None]], Callable[[T, int], Union[T, None]]]
13+
FlatMapFnType = Union[Callable[[T], List[T]], Callable[[T, int], List[T]]]
1314

1415
def list_map(fn: FnType[T], lst: List[T]) -> List[T]:
1516
changed = False
@@ -33,6 +34,24 @@ def list_map(fn: FnType[T], lst: List[T]) -> List[T]:
3334
return mapped_lst if changed else lst # type: ignore
3435

3536

37+
def list_flat_map(fn: FlatMapFnType[T], lst: List[T]) -> List[T]:
38+
changed = False
39+
result: List[T] = []
40+
41+
with_index = len(inspect.signature(fn).parameters) == 2
42+
for index, item in enumerate(lst):
43+
new_items = fn(item, index) if with_index else fn(item) # type: ignore
44+
if new_items is None:
45+
changed = True
46+
continue
47+
48+
if len(new_items) != 1 or new_items[0] is not item:
49+
changed = True
50+
result.extend(new_items)
51+
52+
return result if changed else lst
53+
54+
3655
def list_map_last(fn: Callable[[T], Union[T, None]], lst: List[T]) -> List[T]:
3756
if not lst:
3857
return lst

rewrite/tests/python/all/templating/template_test.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass, field
2-
from typing import Any, Callable, List, cast
2+
from typing import Any, Callable, List, cast, Union
33

4-
from rewrite.java import Literal, P, J, Expression
4+
from rewrite.java import Literal, P, J, Expression, Statement, JavaCoordinates, MethodDeclaration
55
from rewrite.python import PythonVisitor, PythonTemplate, PythonParserBuilder, CompilationUnit, ExpressionStatement
66
from rewrite.test import from_visitor, RecipeSpec, rewrite_run, python
77

@@ -15,7 +15,7 @@ def test_string_substitution():
1515
),
1616
spec=RecipeSpec()
1717
.with_recipe(from_visitor(
18-
ExpressionTemplatingVisitor(lambda j: isinstance(j, Literal), '#{}', [2])))
18+
ReplaceTemplatingVisitor(lambda j: isinstance(j, Literal), '#{}', [2])))
1919
)
2020

2121

@@ -28,7 +28,31 @@ def test_tree_substitution():
2828
),
2929
spec=RecipeSpec()
3030
.with_recipe(from_visitor(
31-
ExpressionTemplatingVisitor(lambda j: isinstance(j, Literal), '#{any()}', [parse_expression('2')])))
31+
ReplaceTemplatingVisitor(lambda j: isinstance(j, Literal), '#{any()}', [parse_expression('2')])))
32+
)
33+
34+
35+
def test_add_statement():
36+
rewrite_run(
37+
# language=python
38+
python(
39+
"""\
40+
def f():
41+
pass
42+
""",
43+
"""\
44+
def f():
45+
pass
46+
return
47+
"""
48+
),
49+
spec=RecipeSpec()
50+
.with_recipe(from_visitor(
51+
AddLastTemplatingVisitor(
52+
lambda j: isinstance(j, MethodDeclaration) and len(j.body.statements) == 1,
53+
'return',
54+
coordinate_provider=lambda m: cast(MethodDeclaration, m).body.statements[0].get_coordinates().after())
55+
))
3256
)
3357

3458

@@ -41,18 +65,20 @@ def test_tree_substitution_named():
4165
),
4266
spec=RecipeSpec()
4367
.with_recipe(from_visitor(
44-
ExpressionTemplatingVisitor(lambda j: isinstance(j, Literal), '#{name:any()} + #{any()} + #{name}', [parse_expression('2'), parse_expression('3')])))
68+
ReplaceTemplatingVisitor(lambda j: isinstance(j, Literal), '#{name:any()} + #{any()} + #{name}', [parse_expression('2'), parse_expression('3')])))
4569
)
4670

4771

4872
def parse_expression(code: str) -> J:
49-
return cast(ExpressionStatement,
50-
cast(CompilationUnit, next(iter(PythonParserBuilder().build().parse_strings(code)))).statements[
51-
0]).expression
73+
return cast(ExpressionStatement, parse_statement(code)).expression
74+
75+
76+
def parse_statement(code: str) -> J:
77+
return cast(CompilationUnit, next(iter(PythonParserBuilder().build().parse_strings(code)))).statements[0]
5278

5379

5480
@dataclass
55-
class ExpressionTemplatingVisitor(PythonVisitor[P]):
81+
class ReplaceTemplatingVisitor(PythonVisitor[P]):
5682
match: Callable[[J], bool]
5783
code: str
5884
params: List[Any] = field(default_factory=list)
@@ -68,3 +94,32 @@ def __post_init__(self):
6894
def visit_expression(self, expr: Expression, p: P) -> J:
6995
return self._template.apply(self.cursor, expr.get_coordinates().replace(), self.params) if self.match(
7096
expr) else expr
97+
98+
def visit_statement(self, stmt: Statement, p: P) -> J:
99+
return self._template.apply(self.cursor, stmt.get_coordinates().replace(), self.params) if self.match(
100+
stmt) else stmt
101+
102+
103+
@dataclass
104+
class AddLastTemplatingVisitor(PythonVisitor[P]):
105+
match: Callable[[J], bool]
106+
code: str
107+
coordinate_provider: Callable[[Union[Expression, Statement]], JavaCoordinates] = lambda j: j.get_coordinates().after()
108+
params: List[Any] = field(default_factory=list)
109+
110+
debug: bool = False
111+
_template: PythonTemplate = field(init=False, repr=False)
112+
113+
def __post_init__(self):
114+
self._template = PythonTemplate(
115+
self.code,
116+
on_after_variable_substitution=lambda code: print(code) if self.debug else None
117+
)
118+
119+
def visit_expression(self, expr: Expression, p: P) -> J:
120+
return self._template.apply(self.cursor, self.coordinate_provider(expr), self.params) if self.match(
121+
expr) else expr
122+
123+
def visit_statement(self, stmt: Statement, p: P) -> J:
124+
return self._template.apply(self.cursor, self.coordinate_provider(stmt), self.params) if self.match(
125+
stmt) else stmt

0 commit comments

Comments
 (0)