Skip to content

Commit 86e61bf

Browse files
authored
[Executorch][Export/ Recipes] Modify pre_edge_transform_passes, edge_transform_passes definition to take exported program and method name. (#14178) (D81730890)
Co-authored-by: Abhinay Kukkadapu <[email protected]> Differential version: D81730890
1 parent 66639e4 commit 86e61bf

File tree

4 files changed

+214
-37
lines changed

4 files changed

+214
-37
lines changed

export/export.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,12 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
195195
elif stage_type == StageType.QUANTIZE:
196196
stage = QuantizeStage(self._quant_recipe)
197197
elif stage_type == StageType.TORCH_EXPORT:
198-
pre_edge_passes = None
199-
if self._export_recipe.pre_edge_transform_passes is not None:
200-
pre_edge_passes = list(
201-
self._export_recipe.pre_edge_transform_passes
198+
aten_transform_passes = None
199+
if self._export_recipe.aten_transform_passes is not None:
200+
aten_transform_passes = list(
201+
self._export_recipe.aten_transform_passes
202202
)
203-
stage = TorchExportStage(pre_edge_passes)
203+
stage = TorchExportStage(aten_transform_passes)
204204
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
205205
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
206206
elif stage_type == StageType.TO_EDGE:

export/recipe.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from abc import ABCMeta, abstractmethod
88
from dataclasses import dataclass
99
from enum import Enum, EnumMeta
10-
from typing import Callable, List, Optional, Sequence
10+
from typing import Callable, List, Optional
1111

1212
import torch
13+
from executorch.exir import ExportedProgram
1314

1415
from executorch.exir._warnings import experimental
1516

@@ -117,12 +118,15 @@ class LoweringRecipe:
117118
118119
Attributes:
119120
partitioners: Optional list of partitioners for model partitioning
120-
edge_transform_passes: Optional sequence of transformation passes to apply
121+
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
122+
and return a list of passes (PassType) to be executed during lowering stages.
121123
edge_compile_config: Optional edge compilation configuration
122124
"""
123125

124126
partitioners: Optional[List[Partitioner]] = None
125-
edge_transform_passes: Optional[Sequence[PassType]] = None
127+
edge_transform_passes: (
128+
None | List[Callable[[str, ExportedProgram], List[PassType]]]
129+
) = None
126130
# pyre-ignore[11]: Type not defined
127131
edge_compile_config: Optional[EdgeCompileConfig] = None
128132

@@ -141,8 +145,8 @@ class ExportRecipe:
141145
Attributes:
142146
name: Optional name for the recipe
143147
quantization_recipe: Optional quantization recipe for model quantization
144-
pre_edge_transform_passes: Optional function to apply transformation passes
145-
before edge lowering
148+
aten_transform_passes: Optional list of functions to apply transformation passes to the program before edge lowering.
149+
These callables are invoked to modify and return the transformed program.
146150
lowering_recipe: Optional lowering recipe for model lowering and partitioning
147151
executorch_backend_config: Optional backend configuration for ExecuTorch
148152
pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline.
@@ -151,7 +155,9 @@ class ExportRecipe:
151155

152156
name: Optional[str] = None
153157
quantization_recipe: Optional[QuantizationRecipe] = None
154-
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
158+
aten_transform_passes: Optional[
159+
List[Callable[[str, ExportedProgram], ExportedProgram]]
160+
] = None
155161
lowering_recipe: Optional[LoweringRecipe] = None
156162
# pyre-ignore[11]: Type not defined
157163
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
@@ -240,8 +246,8 @@ def _combine_recipes( # noqa: C901
240246

241247
for recipe in backend_recipes:
242248
# Collect pre-edge transform passes
243-
if recipe.pre_edge_transform_passes:
244-
all_pre_edge_passes.extend(recipe.pre_edge_transform_passes)
249+
if recipe.aten_transform_passes:
250+
all_pre_edge_passes.extend(recipe.aten_transform_passes)
245251

246252
# Collect partitioners from lowering recipes
247253
if recipe.lowering_recipe and recipe.lowering_recipe.partitioners:
@@ -307,7 +313,7 @@ def _combine_recipes( # noqa: C901
307313
return cls(
308314
name=recipe_name,
309315
quantization_recipe=combined_quantization_recipe,
310-
pre_edge_transform_passes=all_pre_edge_passes,
316+
aten_transform_passes=all_pre_edge_passes,
311317
lowering_recipe=combined_lowering_recipe,
312318
executorch_backend_config=combined_backend_config,
313319
)

export/stages.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
import copy
88
import logging
99
from abc import ABC, abstractmethod
10-
from typing import Any, Callable, Dict, List, Optional, Sequence
10+
from collections import defaultdict
11+
from typing import Any, Callable, Dict, List, Optional
1112

1213
import torch
1314
from executorch.devtools.backend_debug import get_delegation_info
14-
from executorch.exir import EdgeCompileConfig
15+
from executorch.exir import EdgeCompileConfig, ExportedProgram
1516
from executorch.exir.backend.backend_api import validation_disabled
1617
from executorch.exir.program import to_edge, to_edge_transform_and_lower
17-
from executorch.exir.program._program import _transform
1818
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
1919
from executorch.export.types import StageType
2020
from torch import nn
@@ -107,10 +107,12 @@ class TorchExportStage(Stage):
107107

108108
def __init__(
109109
self,
110-
pre_edge_transform_passes: Optional[List[PassType]] = None,
110+
aten_transform_passes: Optional[
111+
List[Callable[[str, ExportedProgram], ExportedProgram]]
112+
] = None,
111113
) -> None:
112114
super().__init__()
113-
self._pre_edge_transform_passes = pre_edge_transform_passes
115+
self._aten_transform_passes = aten_transform_passes
114116

115117
@property
116118
def stage_type(self) -> str:
@@ -149,9 +151,13 @@ def run(self, artifact: PipelineArtifact) -> None:
149151
)
150152

151153
# Apply pre-edge transform passes if available
152-
for pass_ in self._pre_edge_transform_passes or []:
153-
exported_programs[method_name] = _transform(
154-
exported_programs[method_name], pass_
154+
for pass_ in self._aten_transform_passes or []:
155+
if not callable(pass_):
156+
raise ValueError(
157+
"Aten transform passes must be a callable that can transform and return an exported program"
158+
)
159+
exported_programs[method_name] = pass_(
160+
method_name, exported_programs[method_name]
155161
)
156162

157163
self._artifact = artifact.copy_with_new_data(exported_programs)
@@ -165,7 +171,9 @@ class EdgeTransformAndLowerStage(Stage):
165171
def __init__(
166172
self,
167173
partitioners: Optional[List[Any]] = None,
168-
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
174+
transform_passes: (
175+
None | List[Callable[[str, ExportedProgram], List[PassType]]]
176+
) = None,
169177
compile_config: Optional[Any] = None,
170178
) -> None:
171179
self._partitioners = partitioners
@@ -205,11 +213,28 @@ def run(self, artifact: PipelineArtifact) -> None:
205213
constant_methods = artifact.get_context("constant_methods")
206214
generate_etrecord = artifact.get_context("generate_etrecord", False)
207215

216+
# per method transform passes
217+
transform_passes = defaultdict(list)
218+
for method_name, ep in exported_programs.items():
219+
# Resolve transform passes from callable
220+
for pass_ in self._transform_passes or []:
221+
if not callable(pass_):
222+
raise ValueError(
223+
"Transform passes must be a callable that resolves to a list of passes"
224+
)
225+
passes = pass_(method_name, ep)
226+
if isinstance(passes, list):
227+
transform_passes[method_name].extend(passes)
228+
else:
229+
raise ValueError(
230+
"Transform passes must be a callable that resolves to a list of passes"
231+
)
232+
208233
with validation_disabled():
209234
edge_program_manager = to_edge_transform_and_lower(
210235
exported_programs,
211236
partitioner=self._partitioners,
212-
transform_passes=self._transform_passes,
237+
transform_passes=transform_passes,
213238
constant_methods=constant_methods,
214239
compile_config=self._compile_config,
215240
generate_etrecord=generate_etrecord,
@@ -396,7 +421,7 @@ def run(self, artifact: PipelineArtifact) -> None:
396421
captured_graph = torch.export.export(model, inputs, strict=True).module()
397422

398423
quantizer = self._get_quantizer_for_prepare_pt2e(
399-
self._quantization_recipe.quantizers
424+
self._quantization_recipe.quantizers # pyre-ignore
400425
)
401426
prepared_model = prepare_pt2e(captured_graph, quantizer)
402427

@@ -471,7 +496,9 @@ class ToBackendStage(Stage):
471496
def __init__(
472497
self,
473498
partitioners: Optional[List[Any]] = None,
474-
transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None,
499+
transform_passes: (
500+
None | List[Callable[[str, ExportedProgram], List[PassType]]]
501+
) = None,
475502
) -> None:
476503
super().__init__()
477504
self._partitioners = partitioners
@@ -513,11 +540,24 @@ def run(self, artifact: PipelineArtifact) -> None:
513540
if edge_program_manager is None:
514541
raise RuntimeError("Edge program manager is not set.")
515542

516-
# Apply transform passes if available
517-
if self._transform_passes:
518-
edge_program_manager = edge_program_manager.transform(
519-
self._transform_passes
520-
)
543+
# per method transform passes
544+
transform_passes = defaultdict(list)
545+
for method_name in edge_program_manager.methods:
546+
# Resolve transform passes if it's a callable
547+
ep = edge_program_manager.exported_program(method_name)
548+
for pass_ in self._transform_passes or []:
549+
if not callable(pass_):
550+
raise ValueError(
551+
"Transform passes must be a callable that resolves to a list of passes"
552+
)
553+
passes = pass_(method_name, ep)
554+
if isinstance(passes, list):
555+
transform_passes[method_name].extend(passes)
556+
else:
557+
raise ValueError("Transform passes must return list of passes")
558+
559+
# Apply transform passes
560+
edge_program_manager = edge_program_manager.transform(transform_passes)
521561

522562
# Apply partitioners if available
523563
if self._partitioners is not None and len(self._partitioners) > 0:

0 commit comments

Comments
 (0)