55from dataclasses import dataclass , field
66from io import StringIO
77from pathlib import Path
8- from typing import Optional , Callable , Iterable , List
8+ from typing import Optional , Callable , Iterable , List , TypeVar
99from uuid import UUID
1010
1111from rewrite import InMemoryExecutionContext , ParserInput , ParserBuilder , random_id , ParseError , ParseExceptionResult , \
12- ExecutionContext , Recipe , TreeVisitor
12+ ExecutionContext , Recipe , TreeVisitor , SourceFile
1313from rewrite .execution import InMemoryLargeSourceSet
14+ from rewrite .python import CompilationUnit
1415from rewrite .python .parser import PythonParserBuilder
1516
1617
18+ S = TypeVar ('S' , bound = SourceFile )
19+
1720@dataclass (frozen = True , eq = False )
1821class SourceSpec :
1922 _id : UUID
@@ -46,6 +49,12 @@ def after(self) -> Optional[Callable[[str], str]]:
4649 def source_path (self ) -> Optional [Path ]:
4750 return self ._source_path
4851
52+ _after_recipe : Callable [[S ], None ] = lambda _ : None
53+
54+ @property
55+ def after_recipe (self ) -> Callable [[S ], None ]:
56+ return self ._after_recipe
57+
4958
5059@dataclass (frozen = True )
5160class CompositeRecipe (Recipe ):
@@ -118,6 +127,7 @@ def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
118127 for res in result :
119128 if res ._before and res ._after :
120129 source_spec = spec_by_source_file [res ._before ]
130+ source_spec .after_recipe (res ._after )
121131 after_printed = res ._after .print_all ()
122132 if source_spec .after is not None :
123133 after = source_spec .after (after_printed )
@@ -131,13 +141,14 @@ def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
131141 remoting_context .close ()
132142
133143
134- def python (before : str , after : str = None ) -> list [SourceSpec ]:
144+ def python (before : str , after : str = None , after_recipe : Callable [[ CompilationUnit ], None ] = lambda s : None ) -> list [SourceSpec ]:
135145 return [SourceSpec (
136146 random_id (),
137147 PythonParserBuilder (),
138148 textwrap .dedent (before ),
139149 None if after is None else lambda _ : textwrap .dedent (after ),
140- None
150+ None ,
151+ after_recipe
141152 )]
142153
143154
0 commit comments