Skip to content

Commit 70d44ba

Browse files
[Executorch][Export][3/N] Modularize export recipes
Addresses (7) in the rfc: #12660 Changes: 1. Add data class called `LoweringRecipe` 2. Modify current xnnpack recipes to use lowering recipes Fixes: #12933 Differential Revision: [D79120575](https://our.internmc.facebook.com/intern/diff/D79120575/) ghstack-source-id: 299106077 Pull Request resolved: #12938
1 parent 66e7c4a commit 70d44ba

File tree

5 files changed

+121
-27
lines changed

5 files changed

+121
-27
lines changed

backends/xnnpack/recipes/xnnpack_recipe_provider.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from executorch.export import (
2828
BackendRecipeProvider,
2929
ExportRecipe,
30+
LoweringRecipe,
3031
QuantizationRecipe,
3132
RecipeType,
3233
)
@@ -88,12 +89,19 @@ def create_recipe(
8889
)
8990
return None
9091

92+
def _get_xnnpack_lowering_recipe(
93+
self, precision_type: Optional[ConfigPrecisionType] = None
94+
) -> LoweringRecipe:
95+
return LoweringRecipe(
96+
partitioners=[XnnpackPartitioner(precision_type=precision_type)],
97+
edge_compile_config=get_xnnpack_edge_compile_config(),
98+
)
99+
91100
def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
92101
return ExportRecipe(
93102
name=recipe_type.value,
94-
edge_compile_config=get_xnnpack_edge_compile_config(),
103+
lowering_recipe=self._get_xnnpack_lowering_recipe(),
95104
executorch_backend_config=get_xnnpack_executorch_backend_config(),
96-
partitioners=[XnnpackPartitioner()],
97105
)
98106

99107
def _build_quantized_recipe(
@@ -120,9 +128,8 @@ def _build_quantized_recipe(
120128
return ExportRecipe(
121129
name=recipe_type.value,
122130
quantization_recipe=quant_recipe,
123-
edge_compile_config=get_xnnpack_edge_compile_config(),
131+
lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type),
124132
executorch_backend_config=get_xnnpack_executorch_backend_config(),
125-
partitioners=[XnnpackPartitioner(config_precision=precision_type)],
126133
)
127134

128135
def _build_int8da_intx_weight_recipe(
@@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe(
150157
return ExportRecipe(
151158
name=recipe_type.value,
152159
quantization_recipe=quant_recipe,
153-
edge_compile_config=get_xnnpack_edge_compile_config(),
160+
lowering_recipe=self._get_xnnpack_lowering_recipe(),
154161
executorch_backend_config=get_xnnpack_executorch_backend_config(),
155-
partitioners=[XnnpackPartitioner()],
156162
)
157163

158164
def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:

export/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
"""
1616

1717
from .export import export, ExportSession
18-
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
18+
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
1919
from .recipe_provider import BackendRecipeProvider
2020
from .recipe_registry import recipe_registry
21-
21+
from .stages import StageType
2222

2323
__all__ = [
24+
"StageType",
2425
"ExportRecipe",
26+
"LoweringRecipe",
2527
"QuantizationRecipe",
2628
"ExportSession",
2729
"export",

export/export.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tabulate import tabulate
1717
from torch import nn
1818

19-
from .recipe import ExportRecipe, QuantizationRecipe
19+
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
2020
from .stages import (
2121
EdgeTransformAndLowerStage,
2222
ExecutorchStage,
@@ -145,6 +145,10 @@ def __init__(
145145
self._export_recipe.quantization_recipe
146146
)
147147

148+
self._lowering_recipe: Optional[LoweringRecipe] = (
149+
self._export_recipe.lowering_recipe
150+
)
151+
148152
# Default pipeline
149153
self._pipeline_stages = pipeline_stages or self._get_default_pipeline()
150154
self._pipeline = self._build_pipeline_from_stages(self._pipeline_stages)
@@ -200,18 +204,42 @@ def _build_pipeline_from_stages(self, stage_types: List[StageType]) -> List[Stag
200204
stage = TorchExportStage(pre_edge_passes)
201205
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
202206
stage = EdgeTransformAndLowerStage(
203-
partitioners=self._export_recipe.partitioners,
204-
transform_passes=self._export_recipe.edge_transform_passes,
205-
compile_config=self._export_recipe.edge_compile_config,
207+
partitioners=(
208+
self._lowering_recipe.partitioners
209+
if self._lowering_recipe
210+
else None
211+
),
212+
transform_passes=(
213+
self._lowering_recipe.edge_transform_passes
214+
if self._lowering_recipe
215+
else None
216+
),
217+
compile_config=(
218+
self._lowering_recipe.edge_compile_config
219+
if self._lowering_recipe
220+
else None
221+
),
206222
)
207223
elif stage_type == StageType.TO_EDGE:
208224
stage = ToEdgeStage(
209-
edge_compile_config=self._export_recipe.edge_compile_config
225+
edge_compile_config=(
226+
self._lowering_recipe.edge_compile_config
227+
if self._lowering_recipe
228+
else None
229+
),
210230
)
211231
elif stage_type == StageType.TO_BACKEND:
212232
stage = ToBackendStage(
213-
partitioners=self._export_recipe.partitioners,
214-
transform_passes=self._export_recipe.edge_transform_passes,
233+
partitioners=(
234+
self._lowering_recipe.partitioners
235+
if self._lowering_recipe
236+
else None
237+
),
238+
transform_passes=(
239+
self._lowering_recipe.edge_transform_passes
240+
if self._lowering_recipe
241+
else None
242+
),
215243
)
216244
elif stage_type == StageType.TO_EXECUTORCH:
217245
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)

export/recipe.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]:
8787
return self.quantizers
8888

8989

90+
@dataclass
91+
class LoweringRecipe:
92+
"""
93+
Configuration recipe for lowering and partitioning.
94+
95+
This class holds the configuration parameters for lowering a model
96+
to backend-specific representations.
97+
98+
Attributes:
99+
partitioners: Optional list of partitioners for model partitioning
100+
edge_transform_passes: Optional sequence of transformation passes to apply
101+
edge_compile_config: Optional edge compilation configuration
102+
"""
103+
104+
partitioners: Optional[List[Partitioner]] = None
105+
edge_transform_passes: Optional[Sequence[PassType]] = None
106+
# pyre-ignore[11]: Type not defined
107+
edge_compile_config: Optional[EdgeCompileConfig] = None
108+
109+
90110
@experimental(
91111
"This API and all of its related functionality such as ExportSession and ExportRecipe are experimental."
92112
)
@@ -101,27 +121,19 @@ class ExportRecipe:
101121
Attributes:
102122
name: Optional name for the recipe
103123
quantization_recipe: Optional quantization recipe for model quantization
104-
edge_compile_config: Optional edge compilation configuration
124+
lowering_recipe: Optional lowering recipe for model lowering and partitioning
125+
executorch_backend_config: Optional backend configuration for ExecuTorch
105126
pre_edge_transform_passes: Optional function to apply transformation passes
106127
before edge lowering
107-
edge_transform_passes: Optional sequence of transformation passes to apply
108-
during edge lowering
109-
transform_check_ir_validity: Whether to check IR validity during transformation
110-
partitioners: Optional list of partitioners for model partitioning
111-
executorch_backend_config: Optional backend configuration for ExecuTorch
112128
mode: Export mode (debug or release)
113129
"""
114130

115131
name: Optional[str] = None
116132
quantization_recipe: Optional[QuantizationRecipe] = None
117-
# pyre-ignore[11]: Type not defined
118-
edge_compile_config: Optional[EdgeCompileConfig] = None
119-
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
120-
edge_transform_passes: Optional[Sequence[PassType]] = None
121-
transform_check_ir_validity: bool = True
122-
partitioners: Optional[List[Partitioner]] = None
133+
lowering_recipe: Optional[LoweringRecipe] = None
123134
# pyre-ignore[11]: Type not defined
124135
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
136+
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
125137
mode: Mode = Mode.RELEASE
126138

127139
@classmethod

export/tests/test_export_session.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
from executorch.export import ExportRecipe, ExportSession
14+
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
1415
from executorch.export.stages import PipelineArtifact, StageType
1516

1617

@@ -302,3 +303,48 @@ def test_save_to_pte_invalid_name(self) -> None:
302303

303304
with self.assertRaises(AssertionError):
304305
session.save_to_pte(None) # pyre-ignore
306+
307+
308+
class TestExportSessionPipelineBuilding(unittest.TestCase):
309+
"""Test pipeline building and stage configuration."""
310+
311+
def setUp(self) -> None:
312+
self.model = SimpleTestModel()
313+
self.example_inputs = [(torch.randn(2, 10),)]
314+
315+
def test_pipeline_building_with_all_recipes(self) -> None:
316+
"""Test pipeline building with quantization and lowering recipes."""
317+
# Create comprehensive recipes
318+
quant_recipe = QuantizationRecipe(
319+
ao_base_config=[Mock()],
320+
quantizers=[Mock()],
321+
)
322+
lowering_recipe = LoweringRecipe(
323+
partitioners=[Mock()],
324+
edge_transform_passes=[Mock()],
325+
edge_compile_config=Mock(),
326+
)
327+
recipe = ExportRecipe(
328+
name="comprehensive_test",
329+
quantization_recipe=quant_recipe,
330+
lowering_recipe=lowering_recipe,
331+
pre_edge_transform_passes=[Mock()],
332+
executorch_backend_config=Mock(),
333+
)
334+
335+
session = ExportSession(
336+
model=self.model,
337+
example_inputs=self.example_inputs,
338+
export_recipe=recipe,
339+
)
340+
341+
self.assertEqual(len(session._pipeline), 5)
342+
stage_types = [stage.stage_type for stage in session._pipeline]
343+
expected_types = [
344+
StageType.SOURCE_TRANSFORM,
345+
StageType.QUANTIZE,
346+
StageType.TORCH_EXPORT,
347+
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
348+
StageType.TO_EXECUTORCH,
349+
]
350+
self.assertEqual(stage_types, expected_types)

0 commit comments

Comments
 (0)