82 changes: 82 additions & 0 deletions backends/aoti/aoti_partitioner.py
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, List, Optional, Tuple

import torch
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram


@experimental(
"This API and all AOTInductor (AOTI) backend related functionality are experimental."
)
class AotiPartitioner(Partitioner):
"""
Base partitioner for AOTInductor-driven backend integration.

This partitioner creates a single partition containing all operators from the input graph.
It skips core ATen decomposition, allowing the backend to handle decomposition using
AOTInductor's backend-specific decomposition table.

Only operators that cannot be handled by the aoti library will be excluded from
the partition and fall back to ExecuTorch's default or custom handling.
"""

def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None:
"""
Initialize the AOTI partitioner.

Args:
backend_name: The name of the backend (e.g., "CudaBackend", "MetalBackend")
compile_spec: List of compilation specifications
"""
self.delegation_spec = DelegationSpec(backend_name, compile_spec)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
"""

partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
node.meta["delegation_tag"] = tag

partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""
Return the list of operations that should not be decomposed, leaving them for the AOT compiler to handle.
Currently we skip ATen decomposition for all ops and let the backend handle them.
"""
do_not_decompose = set()

for node in ep.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.OpOverload
):
do_not_decompose.add(node.target)
return list(do_not_decompose), None
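
With the shared base in place, a backend-specific partitioner only needs to forward its backend name and compile specs, exactly as the Metal and CUDA partitioners below now do. A minimal sketch of what a new AOTInductor-driven backend would add (the `MyAotiBackend` / `MyAotiPartitioner` names are illustrative, not part of this PR):

```python
# Illustrative sketch: a new AOTInductor-driven delegate subclasses
# AotiPartitioner and forwards its registered backend name plus compile specs.
# "MyAotiBackend" / MyAotiPartitioner are hypothetical names, not part of this PR.
from typing import final, List

from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec


@final
@experimental("This API is experimental.")
class MyAotiPartitioner(AotiPartitioner):
    """Hypothetical partitioner driven by the AOTInductor backend."""

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        # The base class builds the DelegationSpec and implements
        # partition() and ops_to_not_decompose().
        super().__init__("MyAotiBackend", compile_spec)
```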
62 changes: 5 additions & 57 deletions backends/apple/metal/metal_partitioner.py
@@ -4,74 +4,22 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, final, List, Optional, Tuple
from typing import final, List

import torch
from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram


@final
@experimental(
"This API and all of Metal backend related functionality are experimental."
)
class MetalPartitioner(Partitioner):
class MetalPartitioner(AotiPartitioner):
"""
Metal partitioner for AOTInductor backend integration.

This partitioner creates a single partition containing all operators from the input graph.
It skips core ATen decomposition, allowing the Metal backend to handle decomposition using
AOTInductor's MPS-specific decomposition table.

Only operators that cannot be handled by the aoti-mps library will be excluded from
the partition and fall back to ExecuTorch's default or custom handling.
Metal partitioner driven by AOTInductor backend.
"""

def __init__(self, compile_spec: List[CompileSpec]) -> None:
self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
"""

partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
node.meta["delegation_tag"] = tag

partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""
Return a list of operations that should not be decomposed and let the AOT compiler handle them.
Currently we skip ATen decompositon for all ops, and let the Metal backend handle them.
"""
do_not_decompose = set()

for node in ep.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.OpOverload
):
do_not_decompose.add(node.target)
return list(do_not_decompose), None
super().__init__(MetalBackend.__name__, compile_spec)
62 changes: 5 additions & 57 deletions backends/cuda/cuda_partitioner.py
@@ -4,74 +4,22 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, final, List, Optional, Tuple
from typing import final, List

import torch
from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram


@final
@experimental(
"This API and all of cuda backend related functionality are experimental."
)
class CudaPartitioner(Partitioner):
class CudaPartitioner(AotiPartitioner):
"""
CUDA partitioner for AOTInductor backend integration.

This partitioner creates a single partition containing all operators from the input graph.
It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
AOTInductor's CUDA-specific decomposition table.

Only operators that cannot be handled by the aoti-cuda library will be excluded from
the partition and fall back to ExecuTorch's default or custom handling.
CUDA partitioner driven by AOTInductor backend.
"""

def __init__(self, compile_spec: List[CompileSpec]) -> None:
self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
"""

partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
node.meta["delegation_tag"] = tag

partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""
Return a list of operations that should not be decomposed and let the AOT compiler handle them.
Currently we skip ATen decompositon for all ops, and let the cuda backend handle them.
"""
do_not_decompose = set()

for node in ep.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.OpOverload
):
do_not_decompose.add(node.target)
return list(do_not_decompose), None
super().__init__(CudaBackend.__name__, compile_spec)
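
The refactor leaves the public behavior of both partitioners unchanged: the whole graph is still tagged as a single partition. Below is a hedged usage sketch of how `CudaPartitioner` (or equivalently `MetalPartitioner`) might be wired into a lowering flow; the toy model, example inputs, and the empty compile-spec list are placeholders, and the exact export/lowering entry points may vary by ExecuTorch release:

```python
# Sketch only: lowering a toy model through the CUDA AOTI partitioner.
# The model, inputs, and empty compile-spec list are placeholders.
import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x + 1.0)


model = TinyModel().eval()
example_inputs = (torch.randn(4, 8),)

exported = torch.export.export(model, example_inputs)

# The partitioner tags every call_function node with a single "tag0"
# partition, so the whole graph is delegated to the CUDA backend.
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[CudaPartitioner([])],
)
executorch_program = edge.to_executorch()
```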