4 changes: 2 additions & 2 deletions backends/arm/_passes/arm_pass.py
@@ -8,10 +8,10 @@
from abc import abstractmethod
from typing import List, Optional, Set, Type

from executorch.exir.pass_base import ExportPass, NodeMetadata
from executorch.exir.pass_base import ExportPass, NodeMetadata, RequireExportedProgram


class ArmPass(ExportPass):
class ArmPass(RequireExportedProgram, ExportPass):
"""Base class for Arm passes"""

@property
9 changes: 3 additions & 6 deletions backends/transforms/fuse_batch_norm_with_conv.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,23 +12,19 @@
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import ExportPass, PassResult, RequireExportedProgram

from torch.nn.utils.fusion import fuse_conv_bn_weights


class FuseBatchNormWithConvPass(ExportPass):
class FuseBatchNormWithConvPass(RequireExportedProgram, ExportPass):
"""
Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase
memory usage since we serialize new weights to represent the convolution. In most cases,
Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused
with the previous convolution
"""

def __init__(self, exported_program: ExportedProgram):
super().__init__()
self.exported_program = exported_program

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
counter = 0
14 changes: 4 additions & 10 deletions backends/xnnpack/_passes/xnnpack_pass.py
@@ -1,22 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram
from executorch.exir.pass_base import ExportPass, RequireExportedProgram


class XNNPACKPass(ExportPass):
class XNNPACKPass(RequireExportedProgram, ExportPass):
"""
An abstract interface for XNNPACK backend passes.
"""

def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self._exported_program = exported_program

@property
def exported_program(self) -> ExportedProgram:
return self._exported_program
...
20 changes: 20 additions & 0 deletions exir/pass_base.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -36,6 +37,7 @@
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.export import ExportedProgram
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.graph import CodeGen
@@ -734,6 +736,24 @@ def migrate_meta_val(
return res


class RequireExportedProgram:
"""Mixin to require a pass to take an exported program, which is accessed by the exported_program property.
[Inline review comment by @abhinaykukkadapu (Contributor), Nov 3, 2025:]

@Erik-Lundell @zingo Sorry for the late response, and thanks for following up on this change. I've done some "standardization" for cases where passes need an ExportedProgram. The QNN backend has the exact same issue, where some passes want an ExportedProgram; I think we can follow a similar pattern here too. Here is the code pointer: https://github.com/pytorch/executorch/blob/main/backends/qualcomm/recipes/qnn_recipe_provider.py#L142

The basic idea is that the recipe gets a callable similar to aten_transform_passes, and the callable can hold a sequence of passes instantiated with or without an EP. Let me know if you have any questions.

After you are done, I will probably take a stab at creating a single stage for all the pass-related steps, which would then be plugged in at various pre/post stages.

[End of inline review comment; the mixin docstring continues below.]
Note that the mixin needs to be added to the left of the pass class in the inheritance list to get a correct MRO.
"""

def __init__(self, exported_program: ExportedProgram | None = None) -> None:
self._exported_program = exported_program
super().__init__()

@property
def exported_program(self) -> ExportedProgram:
if self._exported_program is None:
raise ValueError(
"Tried to access exported_program, but it was not provided when constructing the pass."
)
return self._exported_program


@runtime_checkable
class ArgSchema(Protocol):
name: str
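For readers following the new mixin, here is a minimal usage sketch, not part of this diff. MyBackendPass is a hypothetical name; the imports mirror the file above, and the key point, as the docstring notes, is that RequireExportedProgram is listed before ExportPass so the MRO is correct.

# Minimal usage sketch (hypothetical MyBackendPass; not part of this PR).
import torch

from executorch.exir.pass_base import ExportPass, PassResult, RequireExportedProgram


class MyBackendPass(RequireExportedProgram, ExportPass):
    """The mixin is listed before ExportPass so the cooperative __init__ chain
    runs RequireExportedProgram.__init__ first and stores the exported program."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # exported_program raises ValueError if the pass was built without one.
        ep = self.exported_program
        # ... use ep (e.g. to resolve parameters) and rewrite graph_module ...
        return PassResult(graph_module, True)


# Constructed with an exported program:    MyBackendPass(exported_program=ep)
# Constructed without one (any exported_program access then raises): MyBackendPass()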
1 change: 1 addition & 0 deletions export/export.py
@@ -182,6 +182,7 @@ def _get_default_pipeline(self) -> List[StageType]:
StageType.QUANTIZE, # Optional stage, returns original model if quant recipe is invalid
StageType.TORCH_EXPORT,
StageType.TO_EDGE_TRANSFORM_AND_LOWER,
StageType.POST_TO_BACKEND,
StageType.TO_EXECUTORCH,
]

3 changes: 3 additions & 0 deletions export/recipe.py
@@ -17,6 +17,7 @@

from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassType
from torchao.core.config import AOBaseConfig
from torchao.quantization.pt2e.quantizer import Quantizer
@@ -122,6 +123,7 @@ class LoweringRecipe:
edge_transform_passes: Optional list of callables that take (method_name: str, exported_program: ExportedProgram) as arguments
and return a list of passes (PassType) to be executed during lowering stages.
edge_compile_config: Optional edge compilation configuration
post_to_backend_passes: Optional list of passes to run after all partitioners have run.
"""

partitioners: Optional[List[Partitioner]] = None
@@ -130,6 +132,7 @@
) = None
# pyre-ignore[11]: Type not defined
edge_compile_config: Optional[EdgeCompileConfig] = None
post_to_backend_passes: list[PassType | type[ExportPass]] | None = None


@experimental(
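A rough sketch of how the new field might be populated, reusing FuseBatchNormWithConvPass from the diff above; the recipe composition itself is illustrative, not taken from this PR.

# Illustrative recipe sketch; only post_to_backend_passes is the field added here.
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.export.recipe import LoweringRecipe

# Classes that subclass RequireExportedProgram are instantiated by the
# POST_TO_BACKEND stage with the current exported_program; plain ExportPass
# subclasses are instantiated with no arguments; instances are used as-is.
lowering_recipe = LoweringRecipe(
    post_to_backend_passes=[FuseBatchNormWithConvPass],
)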
119 changes: 110 additions & 9 deletions export/stages.py
@@ -9,12 +9,14 @@
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, cast, Dict, List, Optional

import torch
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import EdgeCompileConfig, ExportedProgram
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, ExportedProgram
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.pass_base import ExportPass, RequireExportedProgram
from executorch.exir.pass_manager import PassManager
from executorch.exir.program import to_edge, to_edge_transform_and_lower
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.types import StageType
@@ -118,7 +120,7 @@ def __init__(
self.strict = strict

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.TORCH_EXPORT

@property
Expand Down Expand Up @@ -197,7 +199,7 @@ def from_recipe(
)

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.TO_EDGE_TRANSFORM_AND_LOWER

@property
Expand Down Expand Up @@ -266,7 +268,7 @@ def __init__(self, backend_config: Any) -> None:
self._backend_config = backend_config

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.TO_EXECUTORCH

@property
Expand Down Expand Up @@ -304,7 +306,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None:
self._transformed_models: Dict[str, nn.Module] = {}

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.SOURCE_TRANSFORM

@property
Expand Down Expand Up @@ -358,7 +360,7 @@ def __init__(self, quantization_recipe: Optional[QuantizationRecipe]) -> None:
self._quantization_recipe = quantization_recipe

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.QUANTIZE

@property
Expand Down Expand Up @@ -459,7 +461,7 @@ def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStag
)

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.TO_EDGE

@property
Expand Down Expand Up @@ -520,7 +522,7 @@ def from_recipe(
)

@property
def stage_type(self) -> str:
def stage_type(self) -> StageType:
return StageType.TO_BACKEND

@property
Expand Down Expand Up @@ -583,3 +585,102 @@ def delegation_info(self) -> Any:
Returns the delegation info.
"""
return self._artifact.get_context("delegation_info")


class PostToBackendStage(Stage):
"""
Stage: Run passes after all partitioners have done their partitioning.
"""

def __init__(
self,
pass_list_or_manager: (
list[PassType | type[ExportPass]] | PassManager | None
) = None,
edge_compile_config: EdgeCompileConfig | None = None,
) -> None:
super().__init__()
if pass_list_or_manager is None:
pass_list_or_manager = []

self._pass_list_or_manager = pass_list_or_manager
self._edge_compile_config = edge_compile_config

@classmethod
def from_recipe(
cls, lowering_recipe: Optional["LoweringRecipe"]
) -> "PostToBackendStage":
if lowering_recipe is None:
return cls()

return cls(
pass_list_or_manager=lowering_recipe.post_to_backend_passes,
edge_compile_config=lowering_recipe.edge_compile_config,
)

@property
def stage_type(self) -> StageType:
return StageType.POST_TO_BACKEND

@property
def valid_predecessor_stages(self) -> List["StageType"]:
return [StageType.TO_BACKEND, StageType.TO_EDGE_TRANSFORM_AND_LOWER]

@property
def can_start_pipeline(self) -> bool:
return False

def run(self, artifact: PipelineArtifact) -> None:
"""
Run the configured passes (a list or a PassManager) using edge_program_manager.transform().

Args:
    artifact: PipelineArtifact whose data field is expected to contain an EdgeProgramManager.
"""

if self._pass_list_or_manager:
edge_program_manager = cast(EdgeProgramManager, artifact.data)

if isinstance(self._pass_list_or_manager, PassManager):
edge_program_manager = edge_program_manager.transform(
self._pass_list_or_manager, self._edge_compile_config
)
else:
exported_program = edge_program_manager.exported_program()
pass_instances: list[PassType] = []
for _pass in self._pass_list_or_manager:
if isinstance(_pass, type):
if not issubclass(_pass, ExportPass):
raise RuntimeError(
f"Pass {_pass} was not subclass of ExportPass."
)
if issubclass(_pass, RequireExportedProgram):
pass_instance = _pass(
exported_program=exported_program # type: ignore
)
else:
pass_instance = _pass()
pass_instances.append(pass_instance)
else:
pass_instances.append(_pass)

edge_program_manager = edge_program_manager.transform(
pass_instances, self._edge_compile_config
)
# Get delegation info
delegation_info = get_delegation_info(
edge_program_manager.exported_program().graph_module
)

self._artifact = artifact.copy_with_new_data(edge_program_manager)
self._artifact.add_context("delegation_info", delegation_info)
else:
# If pass_list_or_manager is None or an empty list, do nothing.
self._artifact = artifact

@property
def delegation_info(self) -> Any:
"""
Returns the delegation info.
"""
return self._artifact.get_context("delegation_info")
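To summarize how the stage can be configured, here is an illustrative sketch, not from this PR, showing both accepted forms of pass_list_or_manager; it assumes the default PassManager constructor that takes a list of passes.

# Illustrative only: the two accepted forms of pass_list_or_manager.
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
from executorch.export.stages import PostToBackendStage

# 1) A plain list: run() instantiates classes lazily; a RequireExportedProgram
#    subclass receives the current exported_program, other classes get no
#    arguments, and already-constructed instances are used as-is.
stage = PostToBackendStage(
    pass_list_or_manager=[FuseBatchNormWithConvPass, ExportPass()],
)

# 2) A prebuilt PassManager: handed to edge_program_manager.transform() directly,
#    so every pass in it must already be fully constructed.
stage = PostToBackendStage(
    pass_list_or_manager=PassManager(passes=[ExportPass()]),
)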