From 36daa0d8188e3a3f842f631d0e7d996a8a0f83a0 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Mon, 28 Jul 2025 16:39:28 -0700 Subject: [PATCH] [Executorch][Export][2/N] Add to_edge and to_backend stages Address (6) in the rfc: https://github.com/pytorch/executorch/issues/12660 1. Adds stage implementations for `to_edge` and `to_backend` 2. Adds unit tests for the two stages 3. Adds these two stages in the validation pipeline. Fixes #12932 Differential Revision: [D79120576](https://our.internmc.facebook.com/intern/diff/D79120576/) [ghstack-poisoned] --- export/export.py | 14 ++++ export/stages.py | 100 ++++++++++++++++++++++++++- export/tests/test_export_stages.py | 104 +++++++++++++++++++++++++++++ 3 files changed, 217 insertions(+), 1 deletion(-) diff --git a/export/export.py b/export/export.py index 947121ac2f3..ee6ef0ef002 100644 --- a/export/export.py +++ b/export/export.py @@ -25,6 +25,8 @@ SourceTransformStage, Stage, StageType, + ToBackendStage, + ToEdgeStage, TorchExportStage, ) @@ -151,9 +153,12 @@ def __init__( StageType.SOURCE_TRANSFORM: [StageType.QUANTIZE, StageType.TORCH_EXPORT], StageType.QUANTIZE: [StageType.TORCH_EXPORT], StageType.TORCH_EXPORT: [ + StageType.TO_EDGE, StageType.TO_EDGE_TRANSFORM_AND_LOWER, ], StageType.TO_EDGE_TRANSFORM_AND_LOWER: [StageType.TO_EXECUTORCH], + StageType.TO_EDGE: [StageType.TO_BACKEND], + StageType.TO_BACKEND: [StageType.TO_EXECUTORCH], StageType.TO_EXECUTORCH: [], } @@ -199,6 +204,15 @@ def _build_pipeline_from_stages(self, stage_types: List[StageType]) -> List[Stag transform_passes=self._export_recipe.edge_transform_passes, compile_config=self._export_recipe.edge_compile_config, ) + elif stage_type == StageType.TO_EDGE: + stage = ToEdgeStage( + edge_compile_config=self._export_recipe.edge_compile_config + ) + elif stage_type == StageType.TO_BACKEND: + stage = ToBackendStage( + partitioners=self._export_recipe.partitioners, + transform_passes=self._export_recipe.edge_transform_passes, + ) elif stage_type == StageType.TO_EXECUTORCH: stage = ExecutorchStage(self._export_recipe.executorch_backend_config) else: diff --git a/export/stages.py b/export/stages.py index c17899fe29b..d7228bb8bcf 100644 --- a/export/stages.py +++ b/export/stages.py @@ -11,8 +11,9 @@ import torch from executorch.devtools.backend_debug import get_delegation_info +from executorch.exir import EdgeCompileConfig from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.program import to_edge_transform_and_lower +from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.exir.program._program import _transform from executorch.export.recipe import QuantizationRecipe from torch import nn @@ -29,6 +30,8 @@ class StageType(str, Enum): TORCH_EXPORT = "torch_export" TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" TO_EXECUTORCH = "to_executorch" + TO_EDGE = "to_edge" + TO_BACKEND = "to_backend" class PipelineArtifact: @@ -306,3 +309,98 @@ def run(self, artifact: PipelineArtifact) -> None: quantized_models[method_name] = quantized_model self._artifact = artifact.copy_with_new_data(quantized_models) + + +class ToEdgeStage(Stage): + """ + Stage: Convert ExportedProgram to EdgeProgramManager. + """ + + def __init__( + self, + edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore + ) -> None: + super().__init__() + self._edge_compile_config = edge_compile_config + + @property + def stage_type(self) -> str: + return StageType.TO_EDGE + + def run(self, artifact: PipelineArtifact) -> None: + """ + Convert ExportedProgram to EdgeProgramManager. + + Args: + artifact: Contains exported programs and context + """ + exported_programs = artifact.data + constant_methods = artifact.get_context("constant_methods") + + # Convert to edge program manager + edge_program_manager = to_edge( + exported_programs, + constant_methods=constant_methods, + compile_config=self._edge_compile_config, + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + + +class ToBackendStage(Stage): + """ + Stage: Apply transformations and partitioning to EdgeProgramManager. + """ + + def __init__( + self, + partitioners: Optional[List[Any]] = None, + transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + ) -> None: + super().__init__() + self._partitioners = partitioners + self._transform_passes = transform_passes + + @property + def stage_type(self) -> str: + return StageType.TO_BACKEND + + def run(self, artifact: PipelineArtifact) -> None: + """ + Apply transformations and partitioning to EdgeProgramManager. + + Args: + artifact: Contains edge program manager and context + """ + edge_program_manager = artifact.data + + if edge_program_manager is None: + raise RuntimeError("Edge program manager is not set.") + + # Apply transform passes if available + if self._transform_passes: + edge_program_manager = edge_program_manager.transform( + self._transform_passes + ) + + # Apply partitioners if available + if self._partitioners is not None and len(self._partitioners) > 0: + with validation_disabled(): + # pyre-ignore + for partitioner in self._partitioners: + edge_program_manager = edge_program_manager.to_backend(partitioner) + + # Get delegation info + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 5d83b4f9046..2b3e533723a 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -19,6 +19,8 @@ QuantizeStage, SourceTransformStage, StageType, + ToBackendStage, + ToEdgeStage, TorchExportStage, ) from torch.export import ExportedProgram @@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None: self.assertIn( "Example inputs for method forward not found or empty", str(cm.exception) ) + + +class TestToEdgeStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_exported_program = Mock(spec=ExportedProgram) + self.exported_programs = {"forward": self.mock_exported_program} + self.context = {"constant_methods": None} + + @patch("executorch.export.stages.to_edge") + def test_run_success(self, mock_to_edge: Mock) -> None: + mock_edge_manager = Mock(spec=EdgeProgramManager) + mock_to_edge.return_value = mock_edge_manager + mock_config = Mock() + + stage = ToEdgeStage(edge_compile_config=mock_config) + artifact = PipelineArtifact(data=self.exported_programs, context=self.context) + stage.run(artifact) + + # Verify to_edge was called with correct parameters + mock_to_edge.assert_called_once_with( + self.exported_programs, + constant_methods=None, + compile_config=mock_config, + ) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_manager) + + +class TestToBackendStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_edge_manager = Mock(spec=EdgeProgramManager) + self.context = {} + + @patch("executorch.export.stages.get_delegation_info") + def test_run_success_no_transforms_or_partitioners( + self, mock_get_delegation_info: Mock + ) -> None: + # Test successful execution without transforms or partitioners + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + self.mock_edge_manager.exported_program.return_value = mock_exported_program + + stage = ToBackendStage() + artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context) + stage.run(artifact) + + # Verify get_delegation_info was called + mock_get_delegation_info.assert_called_once_with(mock_graph_module) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, self.mock_edge_manager) + self.assertEqual( + result_artifact.get_context("delegation_info"), mock_delegation_info + ) + + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_partitioners_and_passes( + self, mock_get_delegation_info: Mock + ) -> None: + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + + mock_edge_program_manager = Mock(spec=EdgeProgramManager) + mock_edge_program_manager.transform.return_value = mock_edge_program_manager + mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager + + mock_partitioner = Mock() + mock_transform_passes = [Mock(), Mock()] + stage = ToBackendStage( + partitioners=[mock_partitioner], transform_passes=mock_transform_passes + ) + artifact = PipelineArtifact( + data=mock_edge_program_manager, context=self.context + ) + stage.run(artifact) + + # Verify transform and to_backend called correctly + mock_edge_program_manager.transform.assert_called_once_with( + mock_transform_passes + ) + mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner) + + # Verify artifacts contain the backend manager + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_program_manager) + + def test_run_edge_manager_none(self) -> None: + stage = ToBackendStage() + artifact = PipelineArtifact(data=None, context=self.context) + + with self.assertRaises(RuntimeError) as cm: + stage.run(artifact) + self.assertIn("Edge program manager is not set", str(cm.exception))