Skip to content
This repository was archived by the owner on Jan 13, 2026. It is now read-only.

Commit fbfcb5c

Browse files
Add CoordinateBuilder.Block.last_statement() coordinate (#138)
This allows adding a statement as the last statement to a block. Adding this required adding a new `MinimumViableSpacingVisitor` to add a linebreak prefix to any statements that don't already have one and where the preceding statement (if any) doesn't have a terminating semicolon.
1 parent 3f78c08 commit fbfcb5c

File tree

13 files changed

+220
-34
lines changed

13 files changed

+220
-34
lines changed

rewrite/rewrite/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .execution import ExecutionContext, DelegatingExecutionContext, InMemoryExecutionContext, Recipe, RecipeRunException
88
from .markers import *
99
from .tree import Checksum, FileAttributes, SourceFile, Tree, PrintOutputCapture, PrinterFactory
10-
from .utils import random_id, list_map, list_map_last
10+
from .utils import random_id, list_find, list_map, list_map_last
1111
from .visitor import Cursor, TreeVisitor
1212
from .parser import *
1313
from .result import *
@@ -21,6 +21,7 @@
2121
'PrintOutputCapture',
2222
'PrinterFactory',
2323
'random_id',
24+
'list_find',
2425
'list_map',
2526
'list_map_last',
2627
'Cursor',

rewrite/rewrite/java/support_types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,15 @@ def replace(self, loc: Optional[Space.Location] = None) -> JavaCoordinates:
357357
return JavaCoordinates(self.tree, loc or Space.Location.STATEMENT_PREFIX, JavaCoordinates.Mode.REPLACE)
358358

359359

360+
@dataclass
361+
class _BlockCoordinateBuilder(_StatementCoordinateBuilder):
362+
def last_statement(self) -> JavaCoordinates:
363+
return self.before(Space.Location.BLOCK_END)
364+
365+
360366
CoordinateBuilder.Expression = _ExpressionCoordinateBuilder # type: ignore
361367
CoordinateBuilder.Statement = _StatementCoordinateBuilder # type: ignore
368+
CoordinateBuilder.Block = _BlockCoordinateBuilder # type: ignore
362369

363370

364371
@dataclass

rewrite/rewrite/java/tree.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,9 @@ def padding(self) -> PaddingHelper:
733733
def accept_java(self, v: JavaVisitor[P], p: P) -> J:
734734
return v.visit_block(self, p)
735735

736+
def get_coordinates(self) -> CoordinateBuilder.Block:
737+
return CoordinateBuilder.Block(self)
738+
736739
# noinspection PyShadowingBuiltins,PyShadowingNames,DuplicatedCode
737740
@dataclass(frozen=True, eq=False)
738741
class Break(Statement):
@@ -1289,8 +1292,10 @@ def padding(self) -> PaddingHelper:
12891292
return p
12901293

12911294
def printer(self, cursor: Cursor) -> TreeVisitor[Tree, PrintOutputCapture[P]]:
1292-
factory = PrinterFactory.current()
1293-
return factory.create_printer(cursor) if factory else JavaPrinter[PrintOutputCapture[P]]()
1295+
if factory := PrinterFactory.current():
1296+
return factory.create_printer(cursor)
1297+
from .printer import JavaPrinter
1298+
return JavaPrinter[PrintOutputCapture[P]]()
12941299

12951300
def accept_java(self, v: JavaVisitor[P], p: P) -> J:
12961301
return v.visit_compilation_unit(self, p)

rewrite/rewrite/python/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
# Formatter
7171
'AutoFormat',
7272
'BlankLinesVisitor',
73+
'MinimumViableSpacingVisitor',
7374
'NormalizeFormatVisitor',
7475
'NormalizeTabsOrSpacesVisitor',
7576
'SpacesVisitor',

rewrite/rewrite/python/_parser_visitor.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -422,28 +422,25 @@ def visit_While(self, node):
422422

423423
def visit_If(self, node):
424424
prefix = self.__source_before('if')
425-
single_statement_then_body = len(node.body) == 1
426425
condition = j.ControlParentheses(random_id(), self.__whitespace(), Markers.EMPTY,
427-
self.__pad_right(self.__convert(node.test), self.__source_before(
428-
':') if single_statement_then_body else Space.EMPTY))
429-
then = self.__pad_statement(node.body[0]) if single_statement_then_body else self.__pad_right(
430-
self.__convert_block(node.body), Space.EMPTY)
426+
self.__pad_right(self.__convert(node.test), Space.EMPTY))
427+
then = self.__pad_right(self.__convert_block(node.body), Space.EMPTY)
431428
elze = None
432429
if len(node.orelse) > 0:
433430
else_prefix = self.__whitespace()
434431
if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If) and self._source.startswith('elif',
435432
self._cursor):
436-
single_statement_else_body = True
433+
is_elif = True
437434
self._cursor += 2
438435
else:
439-
single_statement_else_body = False
436+
is_elif = False
440437
self._cursor += 4
441438

442439
elze = j.If.Else(
443440
random_id(),
444441
else_prefix,
445442
Markers.EMPTY,
446-
self.__pad_statement(node.orelse[0]) if single_statement_else_body else self.__pad_right(
443+
self.__pad_statement(node.orelse[0]) if is_elif else self.__pad_right(
447444
self.__convert_block(node.orelse), Space.EMPTY
448445
)
449446
)
@@ -2067,7 +2064,7 @@ def __next_lexer_token(self, tokens: Iterator[TokenInfo]) -> TokenInfo:
20672064
def __convert_all(self, trees: Sequence) -> List[J2]:
20682065
return [self.__convert(tree) for tree in trees]
20692066

2070-
def __convert_block(self, statements: Sequence, prefix: str = ':') -> j.Block:
2067+
def __convert_block(self, statements: Sequence[Statement], prefix: str = ':') -> j.Block:
20712068
prefix = self.__source_before(prefix)
20722069
if statements:
20732070
statements = [self.__pad_statement(cast(ast.stmt, s)) for s in statements]

rewrite/rewrite/python/format/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
__all__ = [
88
'AutoFormat',
99
'BlankLinesVisitor',
10+
'MinimumViableSpacingVisitor',
1011
'NormalizeFormatVisitor',
1112
'NormalizeTabsOrSpacesVisitor',
1213
'SpacesVisitor',

rewrite/rewrite/python/format/auto_format.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22

33
from .blank_lines import BlankLinesVisitor
4+
from .minimum_viable_spacing import MinimumViableSpacingVisitor
45
from .normalize_format import NormalizeFormatVisitor
56
from .normalize_line_breaks_visitor import NormalizeLineBreaksVisitor
67
from .normalize_tabs_or_spaces import NormalizeTabsOrSpacesVisitor
@@ -31,6 +32,8 @@ def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) ->
3132

3233
tree = NormalizeFormatVisitor(self._stop_after).visit(tree, p, self._cursor.fork())
3334

35+
tree = MinimumViableSpacingVisitor(self._stop_after).visit(tree, p, self._cursor.fork())
36+
3437
tree = BlankLinesVisitor(cu.get_style(BlankLinesStyle) or IntelliJ.blank_lines(), self._stop_after).visit(tree, p, self._cursor.fork())
3538

3639
tree = SpacesVisitor(cu.get_style(SpacesStyle) or IntelliJ.spaces(), self._stop_after).visit(tree, p, self._cursor.fork())
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from typing import cast, Optional, TypeVar, List
4+
5+
from rewrite import Tree, P, Cursor, list_map, list_find
6+
from rewrite.java import J, Space, Statement, Block, Semicolon
7+
from rewrite.python import PythonVisitor, PyComment
8+
from rewrite.visitor import T
9+
10+
J2 = TypeVar('J2', bound=J)
11+
12+
13+
class MinimumViableSpacingVisitor(PythonVisitor):
14+
def __init__(self, stop_after: Optional[Tree] = None):
15+
self._stop_after = stop_after
16+
self._stop = False
17+
18+
def post_visit(self, tree: T, p: P) -> Optional[T]:
19+
if self._stop_after and tree == self._stop_after:
20+
self._stop = True
21+
22+
owner = self.cursor.parent_tree_cursor().value
23+
if isinstance(tree, Statement) and isinstance(owner, Block) and not tree.prefix.comments and not '\n' in tree.prefix.whitespace:
24+
statement_index = list_find(owner.statements, tree)
25+
previous_statement = owner.padding.statements[statement_index - 1] if statement_index > 0 else None
26+
if not previous_statement or not previous_statement.markers.find_first(Semicolon):
27+
tree = tree.with_prefix(tree.prefix.with_whitespace('\n' + tree.prefix.whitespace))
28+
29+
return tree
30+
31+
def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[T]:
32+
return tree if self._stop else super().visit(tree, p, parent)
33+
34+
35+
def _common_margin(s1, s2):
36+
if s1 is None:
37+
s = str(s2)
38+
return s[s.rfind('\n') + 1:]
39+
40+
min_length = min(len(s1), len(s2))
41+
for i in range(min_length):
42+
if s1[i] != s2[i] or not s1[i].isspace():
43+
return s1[:i]
44+
45+
return s2 if len(s2) < len(s1) else s1
46+
47+
48+
def _concatenate_prefix(j: J, prefix: Space) -> J2:
49+
shift = _common_margin(None, j.prefix.whitespace)
50+
51+
def modify_comment(c: PyComment) -> PyComment:
52+
if len(shift) == 0:
53+
return c
54+
c = c.with_text(c.text.replace('\n', '\n' + shift))
55+
if '\n' in c.suffix:
56+
c = c.with_suffix(c.suffix.replace('\n', '\n' + shift))
57+
return c
58+
59+
comments = j.prefix.comments + list_map(modify_comment, cast(List[PyComment], prefix.comments))
60+
61+
new_prefix = j.prefix
62+
new_prefix = new_prefix.with_whitespace(new_prefix.whitespace + prefix.whitespace)
63+
if comments:
64+
new_prefix = new_prefix.with_comments(comments)
65+
66+
return j.with_prefix(new_prefix)

rewrite/rewrite/python/templating.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ def visit_expression(self, expression: Expression, p: int) -> J:
150150
return expression
151151

152152
def visit_block(self, block: Block, p: P) -> J:
153+
if self.loc == Space.Location.BLOCK_END and block.is_scope(self.insertion_point):
154+
parsed = self.template_parser.parse_block_statements(Cursor(self.cursor, self.insertion_point), Statement,
155+
self.substituted_template, self.loc, self.mode)
156+
gen: List[Statement] = self.substitutions.unsubstitute_all(parsed)
157+
return self.auto_format(block.with_statements(block.statements + gen), p, self.cursor.parent) if gen else block
153158
if self.loc == Space.Location.STATEMENT_PREFIX:
154159
return self.auto_format(block.with_statements(list_flat_map(lambda s: self.get_replacements(s) if s.is_scope(self.insertion_point) else s, block.statements)), p, self.cursor.parent)
155160
return super().visit_block(block, p)

rewrite/rewrite/python/tree.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ def expression(self) -> Expression:
9696
def with_expression(self, expression: Expression) -> Await:
9797
return self if expression is self._expression else replace(self, _expression=expression)
9898

99-
_type: JavaType
99+
_type: Optional[JavaType]
100100

101101
@property
102-
def type(self) -> JavaType:
102+
def type(self) -> Optional[JavaType]:
103103
return self._type
104104

105-
def with_type(self, type: JavaType) -> Await:
105+
def with_type(self, type: Optional[JavaType]) -> Await:
106106
return self if type is self._type else replace(self, _type=type)
107107

108108
def accept_python(self, v: PythonVisitor[P], p: P) -> J:
@@ -340,13 +340,13 @@ def markers(self) -> Markers:
340340
def with_markers(self, markers: Markers) -> ExceptionType:
341341
return self if markers is self._markers else replace(self, _markers=markers)
342342

343-
_type: JavaType
343+
_type: Optional[JavaType]
344344

345345
@property
346-
def type(self) -> JavaType:
346+
def type(self) -> Optional[JavaType]:
347347
return self._type
348348

349-
def with_type(self, type: JavaType) -> ExceptionType:
349+
def with_type(self, type: Optional[JavaType]) -> ExceptionType:
350350
return self if type is self._type else replace(self, _type=type)
351351

352352
_exception_group: bool
@@ -503,13 +503,13 @@ def literal(self) -> Expression:
503503
def with_literal(self, literal: Expression) -> LiteralType:
504504
return self if literal is self._literal else replace(self, _literal=literal)
505505

506-
_type: JavaType
506+
_type: Optional[JavaType]
507507

508508
@property
509-
def type(self) -> JavaType:
509+
def type(self) -> Optional[JavaType]:
510510
return self._type
511511

512-
def with_type(self, type: JavaType) -> LiteralType:
512+
def with_type(self, type: Optional[JavaType]) -> LiteralType:
513513
return self if type is self._type else replace(self, _type=type)
514514

515515
def accept_python(self, v: PythonVisitor[P], p: P) -> J:
@@ -554,13 +554,13 @@ def type_tree(self) -> Expression:
554554
def with_type_tree(self, type_tree: Expression) -> TypeHint:
555555
return self if type_tree is self._type_tree else replace(self, _type_tree=type_tree)
556556

557-
_type: JavaType
557+
_type: Optional[JavaType]
558558

559559
@property
560-
def type(self) -> JavaType:
560+
def type(self) -> Optional[JavaType]:
561561
return self._type
562562

563-
def with_type(self, type: JavaType) -> TypeHint:
563+
def with_type(self, type: Optional[JavaType]) -> TypeHint:
564564
return self if type is self._type else replace(self, _type=type)
565565

566566
def accept_python(self, v: PythonVisitor[P], p: P) -> J:
@@ -703,9 +703,10 @@ def padding(self) -> PaddingHelper:
703703
return p
704704

705705
def printer(self, cursor: Cursor) -> TreeVisitor[Tree, PrintOutputCapture[P]]:
706-
factory = PrinterFactory.current()
706+
if factory := PrinterFactory.current():
707+
return factory.create_printer(cursor)
707708
from .printer import PythonPrinter
708-
return factory.create_printer(cursor) if factory else PythonPrinter[PrintOutputCapture[P]]()
709+
return PythonPrinter[PrintOutputCapture[P]]()
709710

710711
def accept_python(self, v: PythonVisitor[P], p: P) -> J:
711712
return v.visit_compilation_unit(self, p)
@@ -1804,13 +1805,13 @@ def expression(self) -> Expression:
18041805
def with_expression(self, expression: Expression) -> YieldFrom:
18051806
return self if expression is self._expression else replace(self, _expression=expression)
18061807

1807-
_type: JavaType
1808+
_type: Optional[JavaType]
18081809

18091810
@property
1810-
def type(self) -> JavaType:
1811+
def type(self) -> Optional[JavaType]:
18111812
return self._type
18121813

1813-
def with_type(self, type: JavaType) -> YieldFrom:
1814+
def with_type(self, type: Optional[JavaType]) -> YieldFrom:
18141815
return self if type is self._type else replace(self, _type=type)
18151816

18161817
def accept_python(self, v: PythonVisitor[P], p: P) -> J:
@@ -2368,13 +2369,13 @@ def from_(self) -> Expression:
23682369
def with_from(self, from_: Expression) -> ErrorFrom:
23692370
return self.padding.with_from(JLeftPadded.with_element(self._from, from_))
23702371

2371-
_type: JavaType
2372+
_type: Optional[JavaType]
23722373

23732374
@property
2374-
def type(self) -> JavaType:
2375+
def type(self) -> Optional[JavaType]:
23752376
return self._type
23762377

2377-
def with_type(self, type: JavaType) -> ErrorFrom:
2378+
def with_type(self, type: Optional[JavaType]) -> ErrorFrom:
23782379
return self if type is self._type else replace(self, _type=type)
23792380

23802381
@dataclass

0 commit comments

Comments
 (0)