Skip to content

Commit ddc33eb

Browse files
committed
Add ability to run passes after edge in LoweringRecipe
This is neccessary for passes that want to handle nodes not handled by a backend. For example, the cortex_m ReplaceQuantNodesPass wants to only replace quant nodes that have not been delegated. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I4b0856472e0a3b95490c871d58283f6a9be2f4e0
1 parent c40c7ef commit ddc33eb

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

export/recipe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class LoweringRecipe:
129129
None | List[Callable[[str, ExportedProgram], List[PassType]]]
130130
) = None
131131
# pyre-ignore[11]: Type not defined
132+
post_edge_passes: list[PassType] | None = None
132133
edge_compile_config: Optional[EdgeCompileConfig] = None
133134

134135

export/stages.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,12 @@ def __init__(
177177
transform_passes: (
178178
None | List[Callable[[str, ExportedProgram], List[PassType]]]
179179
) = None,
180+
post_edge_passes: list[PassType] | None = None,
180181
compile_config: Optional[Any] = None,
181182
) -> None:
182183
self._partitioners = partitioners
183184
self._transform_passes = transform_passes
185+
self._post_edge_passes = post_edge_passes
184186
self._compile_config = compile_config
185187

186188
@classmethod
@@ -194,6 +196,7 @@ def from_recipe(
194196
partitioners=lowering_recipe.partitioners,
195197
transform_passes=lowering_recipe.edge_transform_passes,
196198
compile_config=lowering_recipe.edge_compile_config,
199+
post_edge_passes=lowering_recipe.post_edge_passes,
197200
)
198201

199202
@property
@@ -242,6 +245,10 @@ def run(self, artifact: PipelineArtifact) -> None:
242245
compile_config=self._compile_config,
243246
generate_etrecord=generate_etrecord,
244247
)
248+
if self._post_edge_passes:
249+
edge_program_manager = edge_program_manager.transform(
250+
self._post_edge_passes
251+
)
245252

246253
delegation_info = get_delegation_info(
247254
edge_program_manager.exported_program().graph_module
@@ -502,10 +509,12 @@ def __init__(
502509
transform_passes: (
503510
None | List[Callable[[str, ExportedProgram], List[PassType]]]
504511
) = None,
512+
post_edge_passes: list[PassType] | None = None,
505513
) -> None:
506514
super().__init__()
507515
self._partitioners = partitioners
508516
self._transform_passes = transform_passes
517+
self._post_edge_passes = post_edge_passes
509518

510519
@classmethod
511520
def from_recipe(
@@ -517,6 +526,7 @@ def from_recipe(
517526
return cls(
518527
partitioners=lowering_recipe.partitioners,
519528
transform_passes=lowering_recipe.edge_transform_passes,
529+
post_edge_passes=lowering_recipe.post_edge_passes,
520530
)
521531

522532
@property
@@ -569,6 +579,9 @@ def run(self, artifact: PipelineArtifact) -> None:
569579
for partitioner in self._partitioners:
570580
edge_program_manager = edge_program_manager.to_backend(partitioner)
571581

582+
if self._post_edge_passes:
583+
edge_program_manager.transform(self._post_edge_passes)
584+
572585
# Get delegation info
573586
delegation_info = get_delegation_info(
574587
edge_program_manager.exported_program().graph_module

0 commit comments

Comments
 (0)