Skip to content

Commit 9751f1c

Browse files
[Executorch][Export][3/N] Modularize export recipes
Pull Request resolved: #12938 Addresses (7) in the rfc: #12660 Changes: 1. Add data class called `LoweringRecipe` 2. Modify current xnnpack recipes to use lowering recipes Fixes: #12933 Differential Revision: [D79120575](https://our.internmc.facebook.com/intern/diff/D79120575/) ghstack-source-id: 299637528
1 parent 9430abd commit 9751f1c

File tree

5 files changed

+122
-27
lines changed

5 files changed

+122
-27
lines changed

backends/xnnpack/recipes/xnnpack_recipe_provider.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from executorch.export import (
2828
BackendRecipeProvider,
2929
ExportRecipe,
30+
LoweringRecipe,
3031
QuantizationRecipe,
3132
RecipeType,
3233
)
@@ -88,12 +89,19 @@ def create_recipe(
8889
)
8990
return None
9091

92+
def _get_xnnpack_lowering_recipe(
93+
self, precision_type: Optional[ConfigPrecisionType] = None
94+
) -> LoweringRecipe:
95+
return LoweringRecipe(
96+
partitioners=[XnnpackPartitioner(precision_type=precision_type)],
97+
edge_compile_config=get_xnnpack_edge_compile_config(),
98+
)
99+
91100
def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
92101
return ExportRecipe(
93102
name=recipe_type.value,
94-
edge_compile_config=get_xnnpack_edge_compile_config(),
103+
lowering_recipe=self._get_xnnpack_lowering_recipe(),
95104
executorch_backend_config=get_xnnpack_executorch_backend_config(),
96-
partitioners=[XnnpackPartitioner()],
97105
)
98106

99107
def _build_quantized_recipe(
@@ -120,9 +128,8 @@ def _build_quantized_recipe(
120128
return ExportRecipe(
121129
name=recipe_type.value,
122130
quantization_recipe=quant_recipe,
123-
edge_compile_config=get_xnnpack_edge_compile_config(),
131+
lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type),
124132
executorch_backend_config=get_xnnpack_executorch_backend_config(),
125-
partitioners=[XnnpackPartitioner(config_precision=precision_type)],
126133
)
127134

128135
def _build_int8da_intx_weight_recipe(
@@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe(
150157
return ExportRecipe(
151158
name=recipe_type.value,
152159
quantization_recipe=quant_recipe,
153-
edge_compile_config=get_xnnpack_edge_compile_config(),
160+
lowering_recipe=self._get_xnnpack_lowering_recipe(),
154161
executorch_backend_config=get_xnnpack_executorch_backend_config(),
155-
partitioners=[XnnpackPartitioner()],
156162
)
157163

158164
def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:

export/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
"""
1616

1717
from .export import export, ExportSession
18-
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
18+
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
1919
from .recipe_provider import BackendRecipeProvider
2020
from .recipe_registry import recipe_registry
21-
21+
from .stages import StageType
2222

2323
__all__ = [
24+
"StageType",
2425
"ExportRecipe",
26+
"LoweringRecipe",
2527
"QuantizationRecipe",
2628
"ExportSession",
2729
"export",

export/export.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tabulate import tabulate
1717
from torch import nn
1818

19-
from .recipe import ExportRecipe, QuantizationRecipe
19+
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
2020
from .stages import (
2121
EdgeTransformAndLowerStage,
2222
ExecutorchStage,
@@ -145,6 +145,10 @@ def __init__(
145145
self._export_recipe.quantization_recipe
146146
)
147147

148+
self._lowering_recipe: Optional[LoweringRecipe] = (
149+
self._export_recipe.lowering_recipe
150+
)
151+
148152
# Stages to run
149153
self._pipeline_stages = pipeline_stages or self._get_default_pipeline()
150154

@@ -189,18 +193,42 @@ def _build_default_stages(self) -> Dict[StageType, Stage]:
189193
stage = TorchExportStage(pre_edge_passes)
190194
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
191195
stage = EdgeTransformAndLowerStage(
192-
partitioners=self._export_recipe.partitioners,
193-
transform_passes=self._export_recipe.edge_transform_passes,
194-
compile_config=self._export_recipe.edge_compile_config,
196+
partitioners=(
197+
self._lowering_recipe.partitioners
198+
if self._lowering_recipe
199+
else None
200+
),
201+
transform_passes=(
202+
self._lowering_recipe.edge_transform_passes
203+
if self._lowering_recipe
204+
else None
205+
),
206+
compile_config=(
207+
self._lowering_recipe.edge_compile_config
208+
if self._lowering_recipe
209+
else None
210+
),
195211
)
196212
elif stage_type == StageType.TO_EDGE:
197213
stage = ToEdgeStage(
198-
edge_compile_config=self._export_recipe.edge_compile_config
214+
edge_compile_config=(
215+
self._lowering_recipe.edge_compile_config
216+
if self._lowering_recipe
217+
else None
218+
),
199219
)
200220
elif stage_type == StageType.TO_BACKEND:
201221
stage = ToBackendStage(
202-
partitioners=self._export_recipe.partitioners,
203-
transform_passes=self._export_recipe.edge_transform_passes,
222+
partitioners=(
223+
self._lowering_recipe.partitioners
224+
if self._lowering_recipe
225+
else None
226+
),
227+
transform_passes=(
228+
self._lowering_recipe.edge_transform_passes
229+
if self._lowering_recipe
230+
else None
231+
),
204232
)
205233
elif stage_type == StageType.TO_EXECUTORCH:
206234
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)

export/recipe.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]:
8787
return self.quantizers
8888

8989

90+
@dataclass
91+
class LoweringRecipe:
92+
"""
93+
Configuration recipe for lowering and partitioning.
94+
95+
This class holds the configuration parameters for lowering a model
96+
to backend-specific representations.
97+
98+
Attributes:
99+
partitioners: Optional list of partitioners for model partitioning
100+
edge_transform_passes: Optional sequence of transformation passes to apply
101+
edge_compile_config: Optional edge compilation configuration
102+
"""
103+
104+
partitioners: Optional[List[Partitioner]] = None
105+
edge_transform_passes: Optional[Sequence[PassType]] = None
106+
# pyre-ignore[11]: Type not defined
107+
edge_compile_config: Optional[EdgeCompileConfig] = None
108+
109+
90110
@experimental(
91111
"This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
92112
)
@@ -101,27 +121,19 @@ class ExportRecipe:
101121
Attributes:
102122
name: Optional name for the recipe
103123
quantization_recipe: Optional quantization recipe for model quantization
104-
edge_compile_config: Optional edge compilation configuration
124+
lowering_recipe: Optional lowering recipe for model lowering and partitioning
125+
executorch_backend_config: Optional backend configuration for ExecuTorch
105126
pre_edge_transform_passes: Optional function to apply transformation passes
106127
before edge lowering
107-
edge_transform_passes: Optional sequence of transformation passes to apply
108-
during edge lowering
109-
transform_check_ir_validity: Whether to check IR validity during transformation
110-
partitioners: Optional list of partitioners for model partitioning
111-
executorch_backend_config: Optional backend configuration for ExecuTorch
112128
mode: Export mode (debug or release)
113129
"""
114130

115131
name: Optional[str] = None
116132
quantization_recipe: Optional[QuantizationRecipe] = None
117-
# pyre-ignore[11]: Type not defined
118-
edge_compile_config: Optional[EdgeCompileConfig] = None
119-
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
120-
edge_transform_passes: Optional[Sequence[PassType]] = None
121-
transform_check_ir_validity: bool = True
122-
partitioners: Optional[List[Partitioner]] = None
133+
lowering_recipe: Optional[LoweringRecipe] = None
123134
# pyre-ignore[11]: Type not defined
124135
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
136+
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
125137
mode: Mode = Mode.RELEASE
126138

127139
@classmethod

export/tests/test_export_session.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
from executorch.export import ExportRecipe, ExportSession
14+
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
1415
from executorch.export.stages import PipelineArtifact, StageType
1516

1617

@@ -329,3 +330,49 @@ def test_save_to_pte_invalid_name(self) -> None:
329330

330331
with self.assertRaises(AssertionError):
331332
session.save_to_pte(None) # pyre-ignore
333+
334+
335+
class TestExportSessionPipelineBuilding(unittest.TestCase):
336+
"""Test pipeline building and stage configuration."""
337+
338+
def setUp(self) -> None:
339+
self.model = SimpleTestModel()
340+
self.example_inputs = [(torch.randn(2, 10),)]
341+
342+
def test_pipeline_building_with_all_recipes(self) -> None:
343+
"""Test pipeline building with quantization and lowering recipes."""
344+
# Create comprehensive recipes
345+
quant_recipe = QuantizationRecipe(
346+
ao_base_config=[Mock()],
347+
quantizers=[Mock()],
348+
)
349+
lowering_recipe = LoweringRecipe(
350+
partitioners=[Mock()],
351+
edge_transform_passes=[Mock()],
352+
edge_compile_config=Mock(),
353+
)
354+
recipe = ExportRecipe(
355+
name="comprehensive_test",
356+
quantization_recipe=quant_recipe,
357+
lowering_recipe=lowering_recipe,
358+
pre_edge_transform_passes=[Mock()],
359+
executorch_backend_config=Mock(),
360+
)
361+
362+
session = ExportSession(
363+
model=self.model,
364+
example_inputs=self.example_inputs,
365+
export_recipe=recipe,
366+
)
367+
368+
registered_stages = session.get_all_registered_stages()
369+
370+
self.assertEqual(len(registered_stages), 5)
371+
expected_types = [
372+
StageType.SOURCE_TRANSFORM,
373+
StageType.QUANTIZE,
374+
StageType.TORCH_EXPORT,
375+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
376+
StageType.TO_EXECUTORCH,
377+
]
378+
self.assertListEqual(list(registered_stages.keys()), expected_types)

0 commit comments

Comments
 (0)