diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index f893eba4fc9..444bf2be642 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -8,10 +8,10 @@ from abc import abstractmethod from typing import List, Optional, Set, Type -from executorch.exir.pass_base import ExportPass, NodeMetadata +from executorch.exir.pass_base import ExportPass, NodeMetadata, RequireExportedProgram -class ArmPass(ExportPass): +class ArmPass(RequireExportedProgram, ExportPass): """Base class for Arm passes""" @property diff --git a/backends/transforms/fuse_batch_norm_with_conv.py b/backends/transforms/fuse_batch_norm_with_conv.py index dda74a3dd6b..fceef010892 100644 --- a/backends/transforms/fuse_batch_norm_with_conv.py +++ b/backends/transforms/fuse_batch_norm_with_conv.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,12 +12,12 @@ from executorch.backends.transforms.utils import get_param_tensor, is_param_node from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import ExportPass, PassResult, RequireExportedProgram from torch.nn.utils.fusion import fuse_conv_bn_weights -class FuseBatchNormWithConvPass(ExportPass): +class FuseBatchNormWithConvPass(RequireExportedProgram, ExportPass): """ Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase memory usage since we serialize new weights to represent the convolution. In most cases, @@ -24,10 +25,6 @@ class FuseBatchNormWithConvPass(ExportPass): with the previous convolution """ - def __init__(self, exported_program: ExportedProgram): - super().__init__() - self.exported_program = exported_program - def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph counter = 0 diff --git a/backends/xnnpack/_passes/xnnpack_pass.py b/backends/xnnpack/_passes/xnnpack_pass.py index 47b7a9fe3d3..3e70b207f61 100644 --- a/backends/xnnpack/_passes/xnnpack_pass.py +++ b/backends/xnnpack/_passes/xnnpack_pass.py @@ -1,22 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from executorch.exir.pass_base import ExportPass -from torch.export import ExportedProgram +from executorch.exir.pass_base import ExportPass, RequireExportedProgram -class XNNPACKPass(ExportPass): +class XNNPACKPass(RequireExportedProgram, ExportPass): """ An abstract interface for XNNPACK backend passes. """ - def __init__(self, exported_program: ExportedProgram) -> None: - super().__init__() - self._exported_program = exported_program - - @property - def exported_program(self) -> ExportedProgram: - return self._exported_program + ... diff --git a/exir/pass_base.py b/exir/pass_base.py index 497970fae34..0151aaebde0 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -36,6 +37,7 @@ from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch.export import ExportedProgram from torch.fx import traceback as fx_traceback from torch.fx.experimental.proxy_tensor import PythonKeyTracer from torch.fx.graph import CodeGen @@ -734,6 +736,24 @@ def migrate_meta_val( return res +class RequireExportedProgram: + """Mixin to require a pass to take an exported program, which is accessed by the exported_program property. + Note that the mixin needs to be added to the left of the pass class in the inheritance list to get a correct MRO. + """ + + def __init__(self, exported_program: ExportedProgram | None = None) -> None: + self._exported_program = exported_program + super().__init__() + + @property + def exported_program(self) -> ExportedProgram: + if self._exported_program is None: + raise ValueError( + "Tried to access exported_program, but it was not provided when constructing the pass." + ) + return self._exported_program + + @runtime_checkable class ArgSchema(Protocol): name: str diff --git a/export/export.py b/export/export.py index 1e9cdbde7c0..c375f51ccac 100644 --- a/export/export.py +++ b/export/export.py @@ -182,6 +182,7 @@ def _get_default_pipeline(self) -> List[StageType]: StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid StageType.TORCH_EXPORT, StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.POST_TO_BACKEND, StageType.TO_EXECUTORCH, ] diff --git a/export/recipe.py b/export/recipe.py index 4465da51956..544fa59b325 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -17,6 +17,7 @@ from executorch.exir.backend.partitioner import Partitioner from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassType from torchao.core.config import AOBaseConfig from torchao.quantization.pt2e.quantizer import Quantizer @@ -122,6 +123,7 @@ class LoweringRecipe: edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments and return a list of passes (PassType) to be executed during lowering stages. edge_compile_config: Optional edge compilation configuration + post_to_backend_passes: Optional list of passes to run after all partitioners have ran. """ partitioners: Optional[List[Partitioner]] = None @@ -130,6 +132,7 @@ class LoweringRecipe: ) = None # pyre-ignore[11]: Type not defined edge_compile_config: Optional[EdgeCompileConfig] = None + post_to_backend_passes: list[PassType | type[ExportPass]] | None = None @experimental( diff --git a/export/stages.py b/export/stages.py index 3be801c6a14..07f20a180a3 100644 --- a/export/stages.py +++ b/export/stages.py @@ -9,12 +9,14 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, cast, Dict, List, Optional import torch from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig, ExportedProgram +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram from executorch.exir.backend.backend_api import validation_disabled +from executorch.exir.pass_base import ExportPass, RequireExportedProgram +from executorch.exir.pass_manager import PassManager from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.types import StageType @@ -118,7 +120,7 @@ def __init__( self.strict = strict @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TORCH_EXPORT @property @@ -197,7 +199,7 @@ def from_recipe( ) @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_EDGE_TRANSFORM_AND_LOWER @property @@ -266,7 +268,7 @@ def __init__(self, backend_config: Any) -> None: self._backend_config = backend_config @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_EXECUTORCH @property @@ -304,7 +306,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: self._transformed_models: Dict[str, nn.Module] = {} @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.SOURCE_TRANSFORM @property @@ -358,7 +360,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: self._quantization_recipe = quantization_recipe @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.QUANTIZE @property @@ -459,7 +461,7 @@ def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStag ) @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_EDGE @property @@ -520,7 +522,7 @@ def from_recipe( ) @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_BACKEND @property @@ -583,3 +585,102 @@ def delegation_info(self) -> Any: Returns the delegation info. """ return self._artifact.get_context("delegation_info") + + +class PostToBackendStage(Stage): + """ + Stage: Run passes after all partitioners have done their partitioning. + """ + + def __init__( + self, + pass_list_or_manager: ( + list[PassType | type[ExportPass]] | PassManager | None + ) = None, + edge_compile_config: EdgeCompileConfig | None = None, + ) -> None: + super().__init__() + if pass_list_or_manager is None: + pass_list_or_manager = [] + + self._pass_list_or_manager = pass_list_or_manager + self._edge_compile_config = edge_compile_config + + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "PostToBackendStage": + if lowering_recipe is None: + return cls() + + return cls( + pass_list=lowering_recipe.post_to_backend_passes, + edge_compile_config=lowering_recipe.edge_compile_config, + ) + + @property + def stage_type(self) -> StageType: + return StageType.POST_TO_BACKEND + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TO_BACKEND, StageType.TO_EDGE_TRANSFORM_AND_LOWER] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Run list of passes using edge_program_manager.transform(). + + Args: + artifact: PipelineArtifact which's data field is expected to contain an edge_program_manager. + """ + + if self._pass_list_or_manager: + edge_program_manager = cast(EdgeProgramManager, artifact.data) + + if isinstance(self._pass_list_or_manager, PassManager): + edge_program_manager = edge_program_manager.transform( + self._pass_list_or_manager, self._edge_compile_config + ) + else: + exported_program = edge_program_manager.exported_program() + pass_instances: list[PassType] = [] + for _pass in self._pass_list_or_manager: + if isinstance(_pass, type): + if not issubclass(_pass, ExportPass): + raise RuntimeError( + f"Pass {_pass} was not subclass of ExportPass." + ) + if issubclass(_pass, RequireExportedProgram): + pass_instance = _pass( + exported_program=exported_program # type: ignore + ) + else: + pass_instance = _pass() + pass_instances.append(pass_instance) + else: + pass_instances.append(_pass) + + edge_program_manager = edge_program_manager.transform( + pass_instances, self._edge_compile_config + ) + # Get delegation info + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + else: + # If pass_list_or_manager is None or empty list, do nothing. + self._artifact = artifact + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 4e8144bd487..b3551bbb393 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,12 +11,14 @@ from unittest.mock import Mock, patch, PropertyMock import torch +from executorch.exir.pass_base import ExportPass, RequireExportedProgram from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager from executorch.export import AOQuantizationConfig, QuantizationRecipe, StageType from executorch.export.stages import ( EdgeTransformAndLowerStage, ExecutorchStage, PipelineArtifact, + PostToBackendStage, QuantizeStage, SourceTransformStage, ToBackendStage, @@ -35,6 +38,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) +class DummyExportPassWithProgram(RequireExportedProgram, ExportPass): + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.received_exported_program = exported_program + + +class DummyExportPassWithoutProgram(ExportPass): + def __init__(self) -> None: + super().__init__() + self.initialized = True + + class TestPipelineArtifact(unittest.TestCase): def test_copy_with_new_data(self) -> None: @@ -587,3 +602,51 @@ def test_run_edge_manager_none(self) -> None: with self.assertRaises(RuntimeError) as cm: stage.run(artifact) self.assertIn("Edge program manager is not set", str(cm.exception)) + + +class TestPostToBackendStage(unittest.TestCase): + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_mixed_pass_types(self, mock_get_delegation_info: Mock) -> None: + mock_get_delegation_info.return_value = {"delegation": "info"} + + exported_program = Mock(spec=ExportedProgram) + edge_program_manager = Mock(spec=EdgeProgramManager) + transformed_manager = Mock(spec=EdgeProgramManager) + transformed_exported_program = Mock(spec=ExportedProgram) + transformed_graph_module = Mock() + transformed_exported_program.graph_module = transformed_graph_module + + edge_program_manager.exported_program.return_value = exported_program + edge_program_manager.transform.return_value = transformed_manager + transformed_manager.exported_program.return_value = transformed_exported_program + + passthrough_pass = Mock() + non_program_pass_instance = DummyExportPassWithoutProgram() + + stage = PostToBackendStage( + pass_list_or_manager=[ + passthrough_pass, + DummyExportPassWithProgram, + DummyExportPassWithoutProgram, + non_program_pass_instance, + ] + ) + + artifact = PipelineArtifact(data=edge_program_manager, context={}) + stage.run(artifact) + + edge_program_manager.transform.assert_called_once() + pass_instances, compile_config = edge_program_manager.transform.call_args[0] + + self.assertEqual(len(pass_instances), 4) + self.assertIs(pass_instances[0], passthrough_pass) + self.assertIsInstance(pass_instances[1], DummyExportPassWithProgram) + self.assertIsInstance(pass_instances[2], DummyExportPassWithoutProgram) + self.assertIs(pass_instances[1].received_exported_program, exported_program) + self.assertIs(pass_instances[3], non_program_pass_instance) + self.assertIsNone(compile_config) + + result_artifact = stage.get_artifacts() + self.assertIs(result_artifact.data, transformed_manager) + self.assertEqual(stage.delegation_info, {"delegation": "info"}) + mock_get_delegation_info.assert_called_once_with(transformed_graph_module) diff --git a/export/types.py b/export/types.py index 760f8461d41..89b46608ffc 100644 --- a/export/types.py +++ b/export/types.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,4 +19,5 @@ class StageType(str, Enum): TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" TO_EDGE = "to_edge" TO_BACKEND = "to_backend" + POST_TO_BACKEND = "post_to_backend" TO_EXECUTORCH = "to_executorch"