From f9f0c69166dab407831cc0fde2eddad739df37ca Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 7 Oct 2025 10:13:59 +0200 Subject: [PATCH 1/2] Add RequireExportProgram mixin for passes It is a common pattern that passes require an exported program, which means that the pass needs to be constructed right before it is ran, with the correct exported program. Today there is no unified way this is done, which makes it tricky for general pass handling to figure out whether to pass an exported program to the constructor or not. This commit solves this by introducing a RequireExportProgram mixin, and makes a best effort to migirate passes to this system. Note that there are passes which require other arguments that still will break general pass handling. Signed-off-by: Erik Lundell Change-Id: Ifaef3069fb04d6ab3df68b1d303aae7dc42585f2 --- backends/arm/_passes/arm_pass.py | 4 ++-- .../transforms/fuse_batch_norm_with_conv.py | 9 +++------ backends/xnnpack/_passes/xnnpack_pass.py | 14 ++++--------- exir/pass_base.py | 20 +++++++++++++++++++ 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index 3cc5e3ee0c0..8ec7ee38cc0 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -9,10 +9,10 @@ from abc import abstractmethod from typing import List, Optional, Set, Type -from executorch.exir.pass_base import ExportPass, NodeMetadata +from executorch.exir.pass_base import ExportPass, NodeMetadata, RequireExportedProgram -class ArmPass(ExportPass): +class ArmPass(RequireExportedProgram, ExportPass): """Base class for Arm passes""" @property diff --git a/backends/transforms/fuse_batch_norm_with_conv.py b/backends/transforms/fuse_batch_norm_with_conv.py index dda74a3dd6b..fceef010892 100644 --- a/backends/transforms/fuse_batch_norm_with_conv.py +++ b/backends/transforms/fuse_batch_norm_with_conv.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,12 +12,12 @@ from executorch.backends.transforms.utils import get_param_tensor, is_param_node from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import ExportPass, PassResult, RequireExportedProgram from torch.nn.utils.fusion import fuse_conv_bn_weights -class FuseBatchNormWithConvPass(ExportPass): +class FuseBatchNormWithConvPass(RequireExportedProgram, ExportPass): """ Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase memory usage since we serialize new weights to represent the convolution. In most cases, @@ -24,10 +25,6 @@ class FuseBatchNormWithConvPass(ExportPass): with the previous convolution """ - def __init__(self, exported_program: ExportedProgram): - super().__init__() - self.exported_program = exported_program - def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph counter = 0 diff --git a/backends/xnnpack/_passes/xnnpack_pass.py b/backends/xnnpack/_passes/xnnpack_pass.py index 47b7a9fe3d3..3e70b207f61 100644 --- a/backends/xnnpack/_passes/xnnpack_pass.py +++ b/backends/xnnpack/_passes/xnnpack_pass.py @@ -1,22 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from executorch.exir.pass_base import ExportPass -from torch.export import ExportedProgram +from executorch.exir.pass_base import ExportPass, RequireExportedProgram -class XNNPACKPass(ExportPass): +class XNNPACKPass(RequireExportedProgram, ExportPass): """ An abstract interface for XNNPACK backend passes. """ - def __init__(self, exported_program: ExportedProgram) -> None: - super().__init__() - self._exported_program = exported_program - - @property - def exported_program(self) -> ExportedProgram: - return self._exported_program + ... diff --git a/exir/pass_base.py b/exir/pass_base.py index 497970fae34..0151aaebde0 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -36,6 +37,7 @@ from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch.export import ExportedProgram from torch.fx import traceback as fx_traceback from torch.fx.experimental.proxy_tensor import PythonKeyTracer from torch.fx.graph import CodeGen @@ -734,6 +736,24 @@ def migrate_meta_val( return res +class RequireExportedProgram: + """Mixin to require a pass to take an exported program, which is accessed by the exported_program property. + Note that the mixin needs to be added to the left of the pass class in the inheritance list to get a correct MRO. + """ + + def __init__(self, exported_program: ExportedProgram | None = None) -> None: + self._exported_program = exported_program + super().__init__() + + @property + def exported_program(self) -> ExportedProgram: + if self._exported_program is None: + raise ValueError( + "Tried to access exported_program, but it was not provided when constructing the pass." + ) + return self._exported_program + + @runtime_checkable class ArgSchema(Protocol): name: str From d72a7301c67d28f80036a42ffa8cdc1fb51b05c8 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 6 Oct 2025 14:03:51 +0200 Subject: [PATCH 2/2] Add PostToBackend stage to Recipes After a discussion in #14588, it was decided to create an additional recipe stage to run passes after partitioning. This is meant for backends that convert ops directly instead of partitioning. Signed-off-by: Erik Lundell Change-Id: Iddf74accb739d4dff16fa46c6fad88ffccfe2f3b --- export/export.py | 1 + export/recipe.py | 3 + export/stages.py | 119 ++++++++++++++++++++++++++--- export/tests/test_export_stages.py | 63 +++++++++++++++ export/types.py | 2 + 5 files changed, 179 insertions(+), 9 deletions(-) diff --git a/export/export.py b/export/export.py index 1e9cdbde7c0..c375f51ccac 100644 --- a/export/export.py +++ b/export/export.py @@ -182,6 +182,7 @@ def _get_default_pipeline(self) -> List[StageType]: StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid StageType.TORCH_EXPORT, StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.POST_TO_BACKEND, StageType.TO_EXECUTORCH, ] diff --git a/export/recipe.py b/export/recipe.py index 4465da51956..544fa59b325 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -17,6 +17,7 @@ from executorch.exir.backend.partitioner import Partitioner from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassType from torchao.core.config import AOBaseConfig from torchao.quantization.pt2e.quantizer import Quantizer @@ -122,6 +123,7 @@ class LoweringRecipe: edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments and return a list of passes (PassType) to be executed during lowering stages. edge_compile_config: Optional edge compilation configuration + post_to_backend_passes: Optional list of passes to run after all partitioners have ran. """ partitioners: Optional[List[Partitioner]] = None @@ -130,6 +132,7 @@ class LoweringRecipe: ) = None # pyre-ignore[11]: Type not defined edge_compile_config: Optional[EdgeCompileConfig] = None + post_to_backend_passes: list[PassType | type[ExportPass]] | None = None @experimental( diff --git a/export/stages.py b/export/stages.py index 3be801c6a14..07f20a180a3 100644 --- a/export/stages.py +++ b/export/stages.py @@ -9,12 +9,14 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, cast, Dict, List, Optional import torch from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig, ExportedProgram +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram from executorch.exir.backend.backend_api import validation_disabled +from executorch.exir.pass_base import ExportPass, RequireExportedProgram +from executorch.exir.pass_manager import PassManager from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.types import StageType @@ -118,7 +120,7 @@ def __init__( self.strict = strict @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TORCH_EXPORT @property @@ -197,7 +199,7 @@ def from_recipe( ) @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_EDGE_TRANSFORM_AND_LOWER @property @@ -266,7 +268,7 @@ def __init__(self, backend_config: Any) -> None: self._backend_config = backend_config @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_EXECUTORCH @property @@ -304,7 +306,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: self._transformed_models: Dict[str, nn.Module] = {} @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.SOURCE_TRANSFORM @property @@ -358,7 +360,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None: self._quantization_recipe = quantization_recipe @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.QUANTIZE @property @@ -459,7 +461,7 @@ def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStag ) @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_EDGE @property @@ -520,7 +522,7 @@ def from_recipe( ) @property - def stage_type(self) -> str: + def stage_type(self) -> StageType: return StageType.TO_BACKEND @property @@ -583,3 +585,102 @@ def delegation_info(self) -> Any: Returns the delegation info. """ return self._artifact.get_context("delegation_info") + + +class PostToBackendStage(Stage): + """ + Stage: Run passes after all partitioners have done their partitioning. + """ + + def __init__( + self, + pass_list_or_manager: ( + list[PassType | type[ExportPass]] | PassManager | None + ) = None, + edge_compile_config: EdgeCompileConfig | None = None, + ) -> None: + super().__init__() + if pass_list_or_manager is None: + pass_list_or_manager = [] + + self._pass_list_or_manager = pass_list_or_manager + self._edge_compile_config = edge_compile_config + + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "PostToBackendStage": + if lowering_recipe is None: + return cls() + + return cls( + pass_list=lowering_recipe.post_to_backend_passes, + edge_compile_config=lowering_recipe.edge_compile_config, + ) + + @property + def stage_type(self) -> StageType: + return StageType.POST_TO_BACKEND + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TO_BACKEND, StageType.TO_EDGE_TRANSFORM_AND_LOWER] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Run list of passes using edge_program_manager.transform(). + + Args: + artifact: PipelineArtifact which's data field is expected to contain an edge_program_manager. + """ + + if self._pass_list_or_manager: + edge_program_manager = cast(EdgeProgramManager, artifact.data) + + if isinstance(self._pass_list_or_manager, PassManager): + edge_program_manager = edge_program_manager.transform( + self._pass_list_or_manager, self._edge_compile_config + ) + else: + exported_program = edge_program_manager.exported_program() + pass_instances: list[PassType] = [] + for _pass in self._pass_list_or_manager: + if isinstance(_pass, type): + if not issubclass(_pass, ExportPass): + raise RuntimeError( + f"Pass {_pass} was not subclass of ExportPass." + ) + if issubclass(_pass, RequireExportedProgram): + pass_instance = _pass( + exported_program=exported_program # type: ignore + ) + else: + pass_instance = _pass() + pass_instances.append(pass_instance) + else: + pass_instances.append(_pass) + + edge_program_manager = edge_program_manager.transform( + pass_instances, self._edge_compile_config + ) + # Get delegation info + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + else: + # If pass_list_or_manager is None or empty list, do nothing. + self._artifact = artifact + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 4e8144bd487..b3551bbb393 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,12 +11,14 @@ from unittest.mock import Mock, patch, PropertyMock import torch +from executorch.exir.pass_base import ExportPass, RequireExportedProgram from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager from executorch.export import AOQuantizationConfig, QuantizationRecipe, StageType from executorch.export.stages import ( EdgeTransformAndLowerStage, ExecutorchStage, PipelineArtifact, + PostToBackendStage, QuantizeStage, SourceTransformStage, ToBackendStage, @@ -35,6 +38,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) +class DummyExportPassWithProgram(RequireExportedProgram, ExportPass): + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.received_exported_program = exported_program + + +class DummyExportPassWithoutProgram(ExportPass): + def __init__(self) -> None: + super().__init__() + self.initialized = True + + class TestPipelineArtifact(unittest.TestCase): def test_copy_with_new_data(self) -> None: @@ -587,3 +602,51 @@ def test_run_edge_manager_none(self) -> None: with self.assertRaises(RuntimeError) as cm: stage.run(artifact) self.assertIn("Edge program manager is not set", str(cm.exception)) + + +class TestPostToBackendStage(unittest.TestCase): + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_mixed_pass_types(self, mock_get_delegation_info: Mock) -> None: + mock_get_delegation_info.return_value = {"delegation": "info"} + + exported_program = Mock(spec=ExportedProgram) + edge_program_manager = Mock(spec=EdgeProgramManager) + transformed_manager = Mock(spec=EdgeProgramManager) + transformed_exported_program = Mock(spec=ExportedProgram) + transformed_graph_module = Mock() + transformed_exported_program.graph_module = transformed_graph_module + + edge_program_manager.exported_program.return_value = exported_program + edge_program_manager.transform.return_value = transformed_manager + transformed_manager.exported_program.return_value = transformed_exported_program + + passthrough_pass = Mock() + non_program_pass_instance = DummyExportPassWithoutProgram() + + stage = PostToBackendStage( + pass_list_or_manager=[ + passthrough_pass, + DummyExportPassWithProgram, + DummyExportPassWithoutProgram, + non_program_pass_instance, + ] + ) + + artifact = PipelineArtifact(data=edge_program_manager, context={}) + stage.run(artifact) + + edge_program_manager.transform.assert_called_once() + pass_instances, compile_config = edge_program_manager.transform.call_args[0] + + self.assertEqual(len(pass_instances), 4) + self.assertIs(pass_instances[0], passthrough_pass) + self.assertIsInstance(pass_instances[1], DummyExportPassWithProgram) + self.assertIsInstance(pass_instances[2], DummyExportPassWithoutProgram) + self.assertIs(pass_instances[1].received_exported_program, exported_program) + self.assertIs(pass_instances[3], non_program_pass_instance) + self.assertIsNone(compile_config) + + result_artifact = stage.get_artifacts() + self.assertIs(result_artifact.data, transformed_manager) + self.assertEqual(stage.delegation_info, {"delegation": "info"}) + mock_get_delegation_info.assert_called_once_with(transformed_graph_module) diff --git a/export/types.py b/export/types.py index 760f8461d41..89b46608ffc 100644 --- a/export/types.py +++ b/export/types.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,4 +19,5 @@ class StageType(str, Enum): TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" TO_EDGE = "to_edge" TO_BACKEND = "to_backend" + POST_TO_BACKEND = "post_to_backend" TO_EXECUTORCH = "to_executorch"