Skip to content

Commit dc7bdd9

Browse files
[Executorch][Export][3/N] Modularize export recipes
Pull Request resolved: #12938 Addresses (7) in the rfc: #12660 Changes: 1. Add data class called `LoweringRecipe` 2. Modify current xnnpack recipes to use lowering recipes Fixes: #12933 ghstack-source-id: 299968042 Differential Revision: [D79120575](https://our.internmc.facebook.com/intern/diff/D79120575/)
1 parent 10c502c commit dc7bdd9

File tree

5 files changed

+117
-24
lines changed

5 files changed

+117
-24
lines changed

backends/xnnpack/recipes/xnnpack_recipe_provider.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from executorch.export import (
2828
BackendRecipeProvider,
2929
ExportRecipe,
30+
LoweringRecipe,
3031
QuantizationRecipe,
3132
RecipeType,
3233
)
@@ -88,12 +89,19 @@ def create_recipe(
8889
)
8990
return None
9091

92+
def _get_xnnpack_lowering_recipe(
93+
self, precision_type: Optional[ConfigPrecisionType] = None
94+
) -> LoweringRecipe:
95+
return LoweringRecipe(
96+
partitioners=[XnnpackPartitioner(precision_type=precision_type)],
97+
edge_compile_config=get_xnnpack_edge_compile_config(),
98+
)
99+
91100
def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
92101
return ExportRecipe(
93102
name=recipe_type.value,
94-
edge_compile_config=get_xnnpack_edge_compile_config(),
103+
lowering_recipe=self._get_xnnpack_lowering_recipe(),
95104
executorch_backend_config=get_xnnpack_executorch_backend_config(),
96-
partitioners=[XnnpackPartitioner()],
97105
)
98106

99107
def _build_quantized_recipe(
@@ -120,9 +128,8 @@ def _build_quantized_recipe(
120128
return ExportRecipe(
121129
name=recipe_type.value,
122130
quantization_recipe=quant_recipe,
123-
edge_compile_config=get_xnnpack_edge_compile_config(),
131+
lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type),
124132
executorch_backend_config=get_xnnpack_executorch_backend_config(),
125-
partitioners=[XnnpackPartitioner(config_precision=precision_type)],
126133
)
127134

128135
def _build_int8da_intx_weight_recipe(
@@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe(
150157
return ExportRecipe(
151158
name=recipe_type.value,
152159
quantization_recipe=quant_recipe,
153-
edge_compile_config=get_xnnpack_edge_compile_config(),
160+
lowering_recipe=self._get_xnnpack_lowering_recipe(),
154161
executorch_backend_config=get_xnnpack_executorch_backend_config(),
155-
partitioners=[XnnpackPartitioner()],
156162
)
157163

158164
def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:

export/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""
1616

1717
from .export import export, ExportSession
18-
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
18+
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
1919
from .recipe_provider import BackendRecipeProvider
2020
from .recipe_registry import recipe_registry
2121
from .types import StageType
2222

2323
__all__ = [
2424
"StageType",
2525
"ExportRecipe",
26+
"LoweringRecipe",
2627
"QuantizationRecipe",
2728
"ExportSession",
2829
"export",

export/export.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tabulate import tabulate
1717
from torch import nn
1818

19-
from .recipe import ExportRecipe, QuantizationRecipe
19+
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
2020
from .stages import (
2121
EdgeTransformAndLowerStage,
2222
ExecutorchStage,
@@ -143,6 +143,10 @@ def __init__(
143143
self._export_recipe.quantization_recipe
144144
)
145145

146+
self._lowering_recipe: Optional[LoweringRecipe] = (
147+
self._export_recipe.lowering_recipe
148+
)
149+
146150
# Stages to run
147151
self._pipeline_stages = (
148152
self._export_recipe.pipeline_stages or self._get_default_pipeline()
@@ -193,18 +197,42 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
193197
stage = TorchExportStage(pre_edge_passes)
194198
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
195199
stage = EdgeTransformAndLowerStage(
196-
partitioners=self._export_recipe.partitioners,
197-
transform_passes=self._export_recipe.edge_transform_passes,
198-
compile_config=self._export_recipe.edge_compile_config,
200+
partitioners=(
201+
self._lowering_recipe.partitioners
202+
if self._lowering_recipe
203+
else None
204+
),
205+
transform_passes=(
206+
self._lowering_recipe.edge_transform_passes
207+
if self._lowering_recipe
208+
else None
209+
),
210+
compile_config=(
211+
self._lowering_recipe.edge_compile_config
212+
if self._lowering_recipe
213+
else None
214+
),
199215
)
200216
elif stage_type == StageType.TO_EDGE:
201217
stage = ToEdgeStage(
202-
edge_compile_config=self._export_recipe.edge_compile_config
218+
edge_compile_config=(
219+
self._lowering_recipe.edge_compile_config
220+
if self._lowering_recipe
221+
else None
222+
),
203223
)
204224
elif stage_type == StageType.TO_BACKEND:
205225
stage = ToBackendStage(
206-
partitioners=self._export_recipe.partitioners,
207-
transform_passes=self._export_recipe.edge_transform_passes,
226+
partitioners=(
227+
self._lowering_recipe.partitioners
228+
if self._lowering_recipe
229+
else None
230+
),
231+
transform_passes=(
232+
self._lowering_recipe.edge_transform_passes
233+
if self._lowering_recipe
234+
else None
235+
),
208236
)
209237
elif stage_type == StageType.TO_EXECUTORCH:
210238
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)

export/recipe.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]:
8989
return self.quantizers
9090

9191

92+
@dataclass
93+
class LoweringRecipe:
94+
"""
95+
Configuration recipe for lowering and partitioning.
96+
97+
This class holds the configuration parameters for lowering a model
98+
to backend-specific representations.
99+
100+
Attributes:
101+
partitioners: Optional list of partitioners for model partitioning
102+
edge_transform_passes: Optional sequence of transformation passes to apply
103+
edge_compile_config: Optional edge compilation configuration
104+
"""
105+
106+
partitioners: Optional[List[Partitioner]] = None
107+
edge_transform_passes: Optional[Sequence[PassType]] = None
108+
# pyre-ignore[11]: Type not defined
109+
edge_compile_config: Optional[EdgeCompileConfig] = None
110+
111+
92112
@experimental(
93113
"This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
94114
)
@@ -103,26 +123,18 @@ class ExportRecipe:
103123
Attributes:
104124
name: Optional name for the recipe
105125
quantization_recipe: Optional quantization recipe for model quantization
106-
edge_compile_config: Optional edge compilation configuration
107126
pre_edge_transform_passes: Optional function to apply transformation passes
108127
before edge lowering
109-
edge_transform_passes: Optional sequence of transformation passes to apply
110-
during edge lowering
111-
transform_check_ir_validity: Whether to check IR validity during transformation
112-
partitioners: Optional list of partitioners for model partitioning
128+
lowering_recipe: Optional lowering recipe for model lowering and partitioning
113129
executorch_backend_config: Optional backend configuration for ExecuTorch
114130
pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline.
115131
mode: Export mode (debug or release)
116132
"""
117133

118134
name: Optional[str] = None
119135
quantization_recipe: Optional[QuantizationRecipe] = None
120-
# pyre-ignore[11]: Type not defined
121-
edge_compile_config: Optional[EdgeCompileConfig] = None
122136
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
123-
edge_transform_passes: Optional[Sequence[PassType]] = None
124-
transform_check_ir_validity: bool = True
125-
partitioners: Optional[List[Partitioner]] = None
137+
lowering_recipe: Optional[LoweringRecipe] = None
126138
# pyre-ignore[11]: Type not defined
127139
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
128140
pipeline_stages: Optional[List[StageType]] = None

export/tests/test_export_session.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import torch
1414
from executorch.export import ExportRecipe, ExportSession
15+
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
1516
from executorch.export.stages import PipelineArtifact
1617
from executorch.export.types import StageType
1718

@@ -434,3 +435,48 @@ def test_save_to_pte_invalid_name(self) -> None:
434435

435436
with self.assertRaises(AssertionError):
436437
session.save_to_pte(None) # pyre-ignore
438+
439+
440+
class TestExportSessionPipelineBuilding(unittest.TestCase):
441+
"""Test pipeline building and stage configuration."""
442+
443+
def setUp(self) -> None:
444+
self.model = SimpleTestModel()
445+
self.example_inputs = [(torch.randn(2, 10),)]
446+
447+
def test_pipeline_building_with_all_recipes(self) -> None:
448+
"""Test pipeline building with quantization and lowering recipes."""
449+
# Create comprehensive recipes
450+
quant_recipe = QuantizationRecipe(
451+
ao_base_config=[Mock()],
452+
quantizers=[Mock()],
453+
)
454+
lowering_recipe = LoweringRecipe(
455+
partitioners=[Mock()],
456+
edge_transform_passes=[Mock()],
457+
edge_compile_config=Mock(),
458+
)
459+
recipe = ExportRecipe(
460+
name="comprehensive_test",
461+
quantization_recipe=quant_recipe,
462+
lowering_recipe=lowering_recipe,
463+
executorch_backend_config=Mock(),
464+
)
465+
466+
session = ExportSession(
467+
model=self.model,
468+
example_inputs=self.example_inputs,
469+
export_recipe=recipe,
470+
)
471+
472+
registered_stages = session.get_all_registered_stages()
473+
474+
self.assertEqual(len(registered_stages), 5)
475+
expected_types = [
476+
StageType.SOURCE_TRANSFORM,
477+
StageType.QUANTIZE,
478+
StageType.TORCH_EXPORT,
479+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
480+
StageType.TO_EXECUTORCH,
481+
]
482+
self.assertListEqual(list(registered_stages.keys()), expected_types)

0 commit comments

Comments
 (0)