Skip to content

Commit 7cf569a

Browse files
Update on "[Executorch][Export][3/N] Modularize export recipes"
Addresses (7) in the rfc: #12660 Changes: 1. Add data class called `LoweringRecipe` 2. Modify current xnnpack recipes to use lowering recipes Fixes: #12933 Differential Revision: [D79120575](https://our.internmc.facebook.com/intern/diff/D79120575/) [ghstack-poisoned]
2 parents f4b7aa5 + 1bd26a5 commit 7cf569a

File tree

2 files changed

+38
-37
lines changed

2 files changed

+38
-37
lines changed

export/export.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -196,44 +196,11 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
196196
)
197197
stage = TorchExportStage(pre_edge_passes)
198198
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
199-
stage = EdgeTransformAndLowerStage(
200-
partitioners=(
201-
self._lowering_recipe.partitioners
202-
if self._lowering_recipe
203-
else None
204-
),
205-
transform_passes=(
206-
self._lowering_recipe.edge_transform_passes
207-
if self._lowering_recipe
208-
else None
209-
),
210-
compile_config=(
211-
self._lowering_recipe.edge_compile_config
212-
if self._lowering_recipe
213-
else None
214-
),
215-
)
199+
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
216200
elif stage_type == StageType.TO_EDGE:
217-
stage = ToEdgeStage(
218-
edge_compile_config=(
219-
self._lowering_recipe.edge_compile_config
220-
if self._lowering_recipe
221-
else None
222-
),
223-
)
201+
stage = ToEdgeStage.from_recipe(self._lowering_recipe)
224202
elif stage_type == StageType.TO_BACKEND:
225-
stage = ToBackendStage(
226-
partitioners=(
227-
self._lowering_recipe.partitioners
228-
if self._lowering_recipe
229-
else None
230-
),
231-
transform_passes=(
232-
self._lowering_recipe.edge_transform_passes
233-
if self._lowering_recipe
234-
else None
235-
),
236-
)
203+
stage = ToBackendStage.from_recipe(self._lowering_recipe)
237204
elif stage_type == StageType.TO_EXECUTORCH:
238205
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
239206
else:

export/stages.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.exir.backend.backend_api import validation_disabled
1515
from executorch.exir.program import to_edge, to_edge_transform_and_lower
1616
from executorch.exir.program._program import _transform
17-
from executorch.export.recipe import QuantizationRecipe
17+
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
1818
from executorch.export.types import StageType
1919
from torch import nn
2020
from torch._export.pass_base import PassType
@@ -168,6 +168,19 @@ def __init__(
168168
self._transform_passes = transform_passes
169169
self._compile_config = compile_config
170170

171+
@classmethod
172+
def from_recipe(
173+
cls, lowering_recipe: Optional["LoweringRecipe"]
174+
) -> "EdgeTransformAndLowerStage":
175+
if lowering_recipe is None:
176+
return cls()
177+
178+
return cls(
179+
partitioners=lowering_recipe.partitioners,
180+
transform_passes=lowering_recipe.edge_transform_passes,
181+
compile_config=lowering_recipe.edge_compile_config,
182+
)
183+
171184
@property
172185
def stage_type(self) -> str:
173186
return StageType.TO_EDGE_TRANSFORM_AND_LOWER
@@ -369,6 +382,15 @@ def __init__(
369382
super().__init__()
370383
self._edge_compile_config = edge_compile_config
371384

385+
@classmethod
386+
def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStage":
387+
if lowering_recipe is None:
388+
return cls()
389+
390+
return cls(
391+
edge_compile_config=lowering_recipe.edge_compile_config,
392+
)
393+
372394
@property
373395
def stage_type(self) -> str:
374396
return StageType.TO_EDGE
@@ -415,6 +437,18 @@ def __init__(
415437
self._partitioners = partitioners
416438
self._transform_passes = transform_passes
417439

440+
@classmethod
441+
def from_recipe(
442+
cls, lowering_recipe: Optional["LoweringRecipe"]
443+
) -> "ToBackendStage":
444+
if lowering_recipe is None:
445+
return cls()
446+
447+
return cls(
448+
partitioners=lowering_recipe.partitioners,
449+
transform_passes=lowering_recipe.edge_transform_passes,
450+
)
451+
418452
@property
419453
def stage_type(self) -> str:
420454
return StageType.TO_BACKEND

0 commit comments

Comments
 (0)