diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py index 9d00c3c9c98..8fba58c12c3 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py @@ -27,6 +27,7 @@ from executorch.export import ( BackendRecipeProvider, ExportRecipe, + LoweringRecipe, QuantizationRecipe, RecipeType, ) @@ -88,12 +89,19 @@ def create_recipe( ) return None + def _get_xnnpack_lowering_recipe( + self, precision_type: Optional[ConfigPrecisionType] = None + ) -> LoweringRecipe: + return LoweringRecipe( + partitioners=[XnnpackPartitioner(precision_type=precision_type)], + edge_compile_config=get_xnnpack_edge_compile_config(), + ) + def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe: return ExportRecipe( name=recipe_type.value, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner()], ) def _build_quantized_recipe( @@ -120,9 +128,8 @@ def _build_quantized_recipe( return ExportRecipe( name=recipe_type.value, quantization_recipe=quant_recipe, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner(config_precision=precision_type)], ) def _build_int8da_intx_weight_recipe( @@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe( return ExportRecipe( name=recipe_type.value, quantization_recipe=quant_recipe, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner()], ) def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: diff --git a/export/__init__.py b/export/__init__.py index 2ee5026d320..d5f3826ab90 100644 --- a/export/__init__.py +++ b/export/__init__.py @@ -15,7 +15,7 @@ """ from .export import export, ExportSession -from .recipe import ExportRecipe, QuantizationRecipe, RecipeType +from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType from .recipe_provider import BackendRecipeProvider from .recipe_registry import recipe_registry from .types import StageType @@ -23,6 +23,7 @@ __all__ = [ "StageType", "ExportRecipe", + "LoweringRecipe", "QuantizationRecipe", "ExportSession", "export", diff --git a/export/export.py b/export/export.py index ac9d894fea1..e5c3b793ccd 100644 --- a/export/export.py +++ b/export/export.py @@ -16,7 +16,7 @@ from tabulate import tabulate from torch import nn -from .recipe import ExportRecipe, QuantizationRecipe +from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe from .stages import ( EdgeTransformAndLowerStage, ExecutorchStage, @@ -143,6 +143,10 @@ def __init__( self._export_recipe.quantization_recipe ) + self._lowering_recipe: Optional[LoweringRecipe] = ( + self._export_recipe.lowering_recipe + ) + # Stages to run self._pipeline_stages = ( self._export_recipe.pipeline_stages or self._get_default_pipeline() @@ -192,20 +196,11 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: ) stage = TorchExportStage(pre_edge_passes) elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: - stage = EdgeTransformAndLowerStage( - partitioners=self._export_recipe.partitioners, - transform_passes=self._export_recipe.edge_transform_passes, - compile_config=self._export_recipe.edge_compile_config, - ) + stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EDGE: - stage = ToEdgeStage( - edge_compile_config=self._export_recipe.edge_compile_config - ) + stage = ToEdgeStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_BACKEND: - stage = ToBackendStage( - partitioners=self._export_recipe.partitioners, - transform_passes=self._export_recipe.edge_transform_passes, - ) + stage = ToBackendStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EXECUTORCH: stage = ExecutorchStage(self._export_recipe.executorch_backend_config) else: diff --git a/export/recipe.py b/export/recipe.py index 315404c54af..8f7251cd419 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -89,6 +89,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]: return self.quantizers +@dataclass +class LoweringRecipe: + """ + Configuration recipe for lowering and partitioning. + + This class holds the configuration parameters for lowering a model + to backend-specific representations. + + Attributes: + partitioners: Optional list of partitioners for model partitioning + edge_transform_passes: Optional sequence of transformation passes to apply + edge_compile_config: Optional edge compilation configuration + """ + + partitioners: Optional[List[Partitioner]] = None + edge_transform_passes: Optional[Sequence[PassType]] = None + # pyre-ignore[11]: Type not defined + edge_compile_config: Optional[EdgeCompileConfig] = None + + @experimental( "This API and all of its related functionality such as ExportSession and ExportRecipe are experimental." ) @@ -103,13 +123,9 @@ class ExportRecipe: Attributes: name: Optional name for the recipe quantization_recipe: Optional quantization recipe for model quantization - edge_compile_config: Optional edge compilation configuration pre_edge_transform_passes: Optional function to apply transformation passes before edge lowering - edge_transform_passes: Optional sequence of transformation passes to apply - during edge lowering - transform_check_ir_validity: Whether to check IR validity during transformation - partitioners: Optional list of partitioners for model partitioning + lowering_recipe: Optional lowering recipe for model lowering and partitioning executorch_backend_config: Optional backend configuration for ExecuTorch pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline. mode: Export mode (debug or release) @@ -117,12 +133,8 @@ class ExportRecipe: name: Optional[str] = None quantization_recipe: Optional[QuantizationRecipe] = None - # pyre-ignore[11]: Type not defined - edge_compile_config: Optional[EdgeCompileConfig] = None pre_edge_transform_passes: Optional[Sequence[PassType]] = None - edge_transform_passes: Optional[Sequence[PassType]] = None - transform_check_ir_validity: bool = True - partitioners: Optional[List[Partitioner]] = None + lowering_recipe: Optional[LoweringRecipe] = None # pyre-ignore[11]: Type not defined executorch_backend_config: Optional[ExecutorchBackendConfig] = None pipeline_stages: Optional[List[StageType]] = None diff --git a/export/stages.py b/export/stages.py index fd27c298028..dd22155e929 100644 --- a/export/stages.py +++ b/export/stages.py @@ -14,7 +14,7 @@ from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.exir.program._program import _transform -from executorch.export.recipe import QuantizationRecipe +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.types import StageType from torch import nn from torch._export.pass_base import PassType @@ -168,6 +168,19 @@ def __init__( self._transform_passes = transform_passes self._compile_config = compile_config + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "EdgeTransformAndLowerStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + transform_passes=lowering_recipe.edge_transform_passes, + compile_config=lowering_recipe.edge_compile_config, + ) + @property def stage_type(self) -> str: return StageType.TO_EDGE_TRANSFORM_AND_LOWER @@ -369,6 +382,15 @@ def __init__( super().__init__() self._edge_compile_config = edge_compile_config + @classmethod + def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStage": + if lowering_recipe is None: + return cls() + + return cls( + edge_compile_config=lowering_recipe.edge_compile_config, + ) + @property def stage_type(self) -> str: return StageType.TO_EDGE @@ -415,6 +437,18 @@ def __init__( self._partitioners = partitioners self._transform_passes = transform_passes + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "ToBackendStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + transform_passes=lowering_recipe.edge_transform_passes, + ) + @property def stage_type(self) -> str: return StageType.TO_BACKEND diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index 7bef0d01876..92aeebb7304 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -12,6 +12,7 @@ import torch from executorch.export import ExportRecipe, ExportSession +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.stages import PipelineArtifact from executorch.export.types import StageType @@ -434,3 +435,48 @@ def test_save_to_pte_invalid_name(self) -> None: with self.assertRaises(AssertionError): session.save_to_pte(None) # pyre-ignore + + +class TestExportSessionPipelineBuilding(unittest.TestCase): + """Test pipeline building and stage configuration.""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + + def test_pipeline_building_with_all_recipes(self) -> None: + """Test pipeline building with quantization and lowering recipes.""" + # Create comprehensive recipes + quant_recipe = QuantizationRecipe( + ao_base_config=[Mock()], + quantizers=[Mock()], + ) + lowering_recipe = LoweringRecipe( + partitioners=[Mock()], + edge_transform_passes=[Mock()], + edge_compile_config=Mock(), + ) + recipe = ExportRecipe( + name="comprehensive_test", + quantization_recipe=quant_recipe, + lowering_recipe=lowering_recipe, + executorch_backend_config=Mock(), + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + registered_stages = session.get_all_registered_stages() + + self.assertEqual(len(registered_stages), 5) + expected_types = [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + self.assertListEqual(list(registered_stages.keys()), expected_types)