Skip to content
This repository was archived by the owner on Jan 13, 2026. It is now read-only.

Commit 805ad8a

Browse files
committed
Allow source spec to include after recipe
1 parent 001375b commit 805ad8a

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

rewrite/rewrite/python/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
'Await',
2828
'ChainedAssignment',
2929
'CollectionLiteral',
30+
'CompilationUnit',
3031
'ComprehensionExpression',
3132
'Del',
3233
'DictLiteral',

rewrite/rewrite/test.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55
from dataclasses import dataclass, field
66
from io import StringIO
77
from pathlib import Path
8-
from typing import Optional, Callable, Iterable, List
8+
from typing import Optional, Callable, Iterable, List, TypeVar
99
from uuid import UUID
1010

1111
from rewrite import InMemoryExecutionContext, ParserInput, ParserBuilder, random_id, ParseError, ParseExceptionResult, \
12-
ExecutionContext, Recipe, TreeVisitor
12+
ExecutionContext, Recipe, TreeVisitor, SourceFile
1313
from rewrite.execution import InMemoryLargeSourceSet
14+
from rewrite.python import CompilationUnit
1415
from rewrite.python.parser import PythonParserBuilder
1516

1617

18+
S = TypeVar('S', bound=SourceFile)
19+
1720
@dataclass(frozen=True, eq=False)
1821
class SourceSpec:
1922
_id: UUID
@@ -46,6 +49,12 @@ def after(self) -> Optional[Callable[[str], str]]:
4649
def source_path(self) -> Optional[Path]:
4750
return self._source_path
4851

52+
_after_recipe: Callable[[S], None] = lambda _: None
53+
54+
@property
55+
def after_recipe(self) -> Callable[[S], None]:
56+
return self._after_recipe
57+
4958

5059
@dataclass(frozen=True)
5160
class CompositeRecipe(Recipe):
@@ -118,6 +127,7 @@ def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
118127
for res in result:
119128
if res._before and res._after:
120129
source_spec = spec_by_source_file[res._before]
130+
source_spec.after_recipe(res._after)
121131
after_printed = res._after.print_all()
122132
if source_spec.after is not None:
123133
after = source_spec.after(after_printed)
@@ -131,13 +141,14 @@ def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
131141
remoting_context.close()
132142

133143

134-
def python(before: str, after: str = None) -> list[SourceSpec]:
144+
def python(before: str, after: str = None, after_recipe: Callable[[CompilationUnit], None] = lambda s: None) -> list[SourceSpec]:
135145
return [SourceSpec(
136146
random_id(),
137147
PythonParserBuilder(),
138148
textwrap.dedent(before),
139149
None if after is None else lambda _: textwrap.dedent(after),
140-
None
150+
None,
151+
after_recipe
141152
)]
142153

143154

rewrite/tests/python/all/format/normalize_format_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from rewrite.python import NormalizeFormatVisitor, PythonVisitor
1+
from rewrite.java import Space
2+
from rewrite.python import NormalizeFormatVisitor, PythonVisitor, CompilationUnit
23
from rewrite.test import rewrite_run, python, RecipeSpec, from_visitor
34

45

@@ -8,6 +9,9 @@ def visit_method_declaration(self, method, p):
89

910

1011
def test_remove_decorator():
12+
def assert_prefix(cu: CompilationUnit):
13+
assert cu.statements[1].prefix == Space([], '\n')
14+
1115
rewrite_run(
1216
# language=python
1317
python(
@@ -24,7 +28,8 @@ def f(n):
2428
2529
def f(n):
2630
return n
27-
"""
31+
""",
32+
assert_prefix
2833
),
2934
spec=RecipeSpec()
3035
.with_recipes(

0 commit comments

Comments
 (0)