Skip to content

Commit d72a730

Browse files
committed
Add PostToBackend stage to Recipes
After a discussion in #14588, it was decided to create an additional recipe stage to run passes after partitioning. This is meant for backends that convert ops directly instead of partitioning. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Iddf74accb739d4dff16fa46c6fad88ffccfe2f3b
1 parent f9f0c69 commit d72a730

File tree

5 files changed

+179
-9
lines changed

5 files changed

+179
-9
lines changed

export/export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def _get_default_pipeline(self) -> List[StageType]:
182182
StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid
183183
StageType.TORCH_EXPORT,
184184
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
185+
StageType.POST_TO_BACKEND,
185186
StageType.TO_EXECUTORCH,
186187
]
187188

export/recipe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from executorch.exir.backend.partitioner import Partitioner
1919
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
20+
from executorch.exir.pass_base import ExportPass
2021
from executorch.exir.pass_manager import PassType
2122
from torchao.core.config import AOBaseConfig
2223
from torchao.quantization.pt2e.quantizer import Quantizer
@@ -122,6 +123,7 @@ class LoweringRecipe:
122123
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
123124
and return a list of passes (PassType) to be executed during lowering stages.
124125
edge_compile_config: Optional edge compilation configuration
126+
post_to_backend_passes: Optional list of passes to run after all partitioners have ran.
125127
"""
126128

127129
partitioners: Optional[List[Partitioner]] = None
@@ -130,6 +132,7 @@ class LoweringRecipe:
130132
) = None
131133
# pyre-ignore[11]: Type not defined
132134
edge_compile_config: Optional[EdgeCompileConfig] = None
135+
post_to_backend_passes: list[PassType | type[ExportPass]] | None = None
133136

134137

135138
@experimental(

export/stages.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
import logging
1010
from abc import ABC, abstractmethod
1111
from collections import defaultdict
12-
from typing import Any, Callable, Dict, List, Optional
12+
from typing import Any, Callable, cast, Dict, List, Optional
1313

1414
import torch
1515
from executorch.devtools.backend_debug import get_delegation_info
16-
from executorch.exir import EdgeCompileConfig, ExportedProgram
16+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram
1717
from executorch.exir.backend.backend_api import validation_disabled
18+
from executorch.exir.pass_base import ExportPass, RequireExportedProgram
19+
from executorch.exir.pass_manager import PassManager
1820
from executorch.exir.program import to_edge, to_edge_transform_and_lower
1921
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
2022
from executorch.export.types import StageType
@@ -118,7 +120,7 @@ def __init__(
118120
self.strict = strict
119121

120122
@property
121-
def stage_type(self) -> str:
123+
def stage_type(self) -> StageType:
122124
return StageType.TORCH_EXPORT
123125

124126
@property
@@ -197,7 +199,7 @@ def from_recipe(
197199
)
198200

199201
@property
200-
def stage_type(self) -> str:
202+
def stage_type(self) -> StageType:
201203
return StageType.TO_EDGE_TRANSFORM_AND_LOWER
202204

203205
@property
@@ -266,7 +268,7 @@ def __init__(self, backend_config: Any) -> None:
266268
self._backend_config = backend_config
267269

268270
@property
269-
def stage_type(self) -> str:
271+
def stage_type(self) -> StageType:
270272
return StageType.TO_EXECUTORCH
271273

272274
@property
@@ -304,7 +306,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None:
304306
self._transformed_models: Dict[str, nn.Module] = {}
305307

306308
@property
307-
def stage_type(self) -> str:
309+
def stage_type(self) -> StageType:
308310
return StageType.SOURCE_TRANSFORM
309311

310312
@property
@@ -358,7 +360,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None:
358360
self._quantization_recipe = quantization_recipe
359361

360362
@property
361-
def stage_type(self) -> str:
363+
def stage_type(self) -> StageType:
362364
return StageType.QUANTIZE
363365

364366
@property
@@ -459,7 +461,7 @@ def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStag
459461
)
460462

461463
@property
462-
def stage_type(self) -> str:
464+
def stage_type(self) -> StageType:
463465
return StageType.TO_EDGE
464466

465467
@property
@@ -520,7 +522,7 @@ def from_recipe(
520522
)
521523

522524
@property
523-
def stage_type(self) -> str:
525+
def stage_type(self) -> StageType:
524526
return StageType.TO_BACKEND
525527

526528
@property
@@ -583,3 +585,102 @@ def delegation_info(self) -> Any:
583585
Returns the delegation info.
584586
"""
585587
return self._artifact.get_context("delegation_info")
588+
589+
590+
class PostToBackendStage(Stage):
591+
"""
592+
Stage: Run passes after all partitioners have done their partitioning.
593+
"""
594+
595+
def __init__(
596+
self,
597+
pass_list_or_manager: (
598+
list[PassType | type[ExportPass]] | PassManager | None
599+
) = None,
600+
edge_compile_config: EdgeCompileConfig | None = None,
601+
) -> None:
602+
super().__init__()
603+
if pass_list_or_manager is None:
604+
pass_list_or_manager = []
605+
606+
self._pass_list_or_manager = pass_list_or_manager
607+
self._edge_compile_config = edge_compile_config
608+
609+
@classmethod
610+
def from_recipe(
611+
cls, lowering_recipe: Optional["LoweringRecipe"]
612+
) -> "PostToBackendStage":
613+
if lowering_recipe is None:
614+
return cls()
615+
616+
return cls(
617+
pass_list=lowering_recipe.post_to_backend_passes,
618+
edge_compile_config=lowering_recipe.edge_compile_config,
619+
)
620+
621+
@property
622+
def stage_type(self) -> StageType:
623+
return StageType.POST_TO_BACKEND
624+
625+
@property
626+
def valid_predecessor_stages(self) -> List["StageType"]:
627+
return [StageType.TO_BACKEND, StageType.TO_EDGE_TRANSFORM_AND_LOWER]
628+
629+
@property
630+
def can_start_pipeline(self) -> bool:
631+
return False
632+
633+
def run(self, artifact: PipelineArtifact) -> None:
634+
"""
635+
Run list of passes using edge_program_manager.transform().
636+
637+
Args:
638+
artifact: PipelineArtifact which's data field is expected to contain an edge_program_manager.
639+
"""
640+
641+
if self._pass_list_or_manager:
642+
edge_program_manager = cast(EdgeProgramManager, artifact.data)
643+
644+
if isinstance(self._pass_list_or_manager, PassManager):
645+
edge_program_manager = edge_program_manager.transform(
646+
self._pass_list_or_manager, self._edge_compile_config
647+
)
648+
else:
649+
exported_program = edge_program_manager.exported_program()
650+
pass_instances: list[PassType] = []
651+
for _pass in self._pass_list_or_manager:
652+
if isinstance(_pass, type):
653+
if not issubclass(_pass, ExportPass):
654+
raise RuntimeError(
655+
f"Pass {_pass} was not subclass of ExportPass."
656+
)
657+
if issubclass(_pass, RequireExportedProgram):
658+
pass_instance = _pass(
659+
exported_program=exported_program # type: ignore
660+
)
661+
else:
662+
pass_instance = _pass()
663+
pass_instances.append(pass_instance)
664+
else:
665+
pass_instances.append(_pass)
666+
667+
edge_program_manager = edge_program_manager.transform(
668+
pass_instances, self._edge_compile_config
669+
)
670+
# Get delegation info
671+
delegation_info = get_delegation_info(
672+
edge_program_manager.exported_program().graph_module
673+
)
674+
675+
self._artifact = artifact.copy_with_new_data(edge_program_manager)
676+
self._artifact.add_context("delegation_info", delegation_info)
677+
else:
678+
# If pass_list_or_manager is None or empty list, do nothing.
679+
self._artifact = artifact
680+
681+
@property
682+
def delegation_info(self) -> Any:
683+
"""
684+
Returns the delegation info.
685+
"""
686+
return self._artifact.get_context("delegation_info")

export/tests/test_export_stages.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,12 +11,14 @@
1011
from unittest.mock import Mock, patch, PropertyMock
1112

1213
import torch
14+
from executorch.exir.pass_base import ExportPass, RequireExportedProgram
1315
from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager
1416
from executorch.export import AOQuantizationConfig, QuantizationRecipe, StageType
1517
from executorch.export.stages import (
1618
EdgeTransformAndLowerStage,
1719
ExecutorchStage,
1820
PipelineArtifact,
21+
PostToBackendStage,
1922
QuantizeStage,
2023
SourceTransformStage,
2124
ToBackendStage,
@@ -35,6 +38,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3538
return self.linear(x)
3639

3740

41+
class DummyExportPassWithProgram(RequireExportedProgram, ExportPass):
42+
def __init__(self, exported_program: ExportedProgram) -> None:
43+
super().__init__()
44+
self.received_exported_program = exported_program
45+
46+
47+
class DummyExportPassWithoutProgram(ExportPass):
48+
def __init__(self) -> None:
49+
super().__init__()
50+
self.initialized = True
51+
52+
3853
class TestPipelineArtifact(unittest.TestCase):
3954

4055
def test_copy_with_new_data(self) -> None:
@@ -587,3 +602,51 @@ def test_run_edge_manager_none(self) -> None:
587602
with self.assertRaises(RuntimeError) as cm:
588603
stage.run(artifact)
589604
self.assertIn("Edge program manager is not set", str(cm.exception))
605+
606+
607+
class TestPostToBackendStage(unittest.TestCase):
608+
@patch("executorch.export.stages.get_delegation_info")
609+
def test_run_with_mixed_pass_types(self, mock_get_delegation_info: Mock) -> None:
610+
mock_get_delegation_info.return_value = {"delegation": "info"}
611+
612+
exported_program = Mock(spec=ExportedProgram)
613+
edge_program_manager = Mock(spec=EdgeProgramManager)
614+
transformed_manager = Mock(spec=EdgeProgramManager)
615+
transformed_exported_program = Mock(spec=ExportedProgram)
616+
transformed_graph_module = Mock()
617+
transformed_exported_program.graph_module = transformed_graph_module
618+
619+
edge_program_manager.exported_program.return_value = exported_program
620+
edge_program_manager.transform.return_value = transformed_manager
621+
transformed_manager.exported_program.return_value = transformed_exported_program
622+
623+
passthrough_pass = Mock()
624+
non_program_pass_instance = DummyExportPassWithoutProgram()
625+
626+
stage = PostToBackendStage(
627+
pass_list_or_manager=[
628+
passthrough_pass,
629+
DummyExportPassWithProgram,
630+
DummyExportPassWithoutProgram,
631+
non_program_pass_instance,
632+
]
633+
)
634+
635+
artifact = PipelineArtifact(data=edge_program_manager, context={})
636+
stage.run(artifact)
637+
638+
edge_program_manager.transform.assert_called_once()
639+
pass_instances, compile_config = edge_program_manager.transform.call_args[0]
640+
641+
self.assertEqual(len(pass_instances), 4)
642+
self.assertIs(pass_instances[0], passthrough_pass)
643+
self.assertIsInstance(pass_instances[1], DummyExportPassWithProgram)
644+
self.assertIsInstance(pass_instances[2], DummyExportPassWithoutProgram)
645+
self.assertIs(pass_instances[1].received_exported_program, exported_program)
646+
self.assertIs(pass_instances[3], non_program_pass_instance)
647+
self.assertIsNone(compile_config)
648+
649+
result_artifact = stage.get_artifacts()
650+
self.assertIs(result_artifact.data, transformed_manager)
651+
self.assertEqual(stage.delegation_info, {"delegation": "info"})
652+
mock_get_delegation_info.assert_called_once_with(transformed_graph_module)

export/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -18,4 +19,5 @@ class StageType(str, Enum):
1819
TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower"
1920
TO_EDGE = "to_edge"
2021
TO_BACKEND = "to_backend"
22+
POST_TO_BACKEND = "post_to_backend"
2123
TO_EXECUTORCH = "to_executorch"

0 commit comments

Comments
 (0)