diff --git a/backends/aoti/aoti_partitioner.py b/backends/aoti/aoti_partitioner.py
new file mode 100644
index 00000000000..2ebb322dbd6
--- /dev/null
+++ b/backends/aoti/aoti_partitioner.py
@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
+from torch.export.exported_program import ExportedProgram
+
+
+@experimental(
+    "This API and all AOTInductor backend related functionality are experimental."
+)
+class AotiPartitioner(Partitioner):
+    """
+    Base partitioner for AOTInductor-driven backend integration.
+
+    This partitioner creates a single partition containing all operators from the input graph.
+    It skips core ATen decomposition, allowing the backend to handle decomposition using
+    AOTInductor's backend-specific decomposition table.
+
+    Only operators that cannot be handled by the aoti library will be excluded from
+    the partition and will fall back to ExecuTorch's default or custom handling.
+    """
+
+    def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None:
+        """
+        Initialize the AOTI partitioner.
+
+        Args:
+            backend_name: The name of the backend (e.g., "CudaBackend", "MetalBackend")
+            compile_spec: List of compilation specifications
+        """
+        self.delegation_spec = DelegationSpec(backend_name, compile_spec)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        """
+        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
+        """
+
+        partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
+        for node in exported_program.graph.nodes:
+            if node.op != "call_function":
+                continue
+            node.meta["delegation_tag"] = tag
+
+        partition_tags[tag] = self.delegation_spec
+
+        tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Return a list of operations that should not be decomposed; the AOT compiler will handle them.
+        Currently we skip ATen decomposition for all ops and let the backend handle them.
+ """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl index b7386403679..560cf52e06f 100644 --- a/backends/aoti/targets.bzl +++ b/backends/aoti/targets.bzl @@ -1,6 +1,21 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") def define_common_targets(): + runtime.python_library( + name = "aoti_partitioner", + srcs = [ + "aoti_partitioner.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/backend:partitioner", + "//executorch/exir/backend:utils", + ], + ) + # AOTI common shims functionality runtime.cxx_library( name = "common_shims", diff --git a/backends/apple/metal/metal_partitioner.py b/backends/apple/metal/metal_partitioner.py index b103ac0f455..e2672f6b554 100644 --- a/backends/apple/metal/metal_partitioner.py +++ b/backends/apple/metal/metal_partitioner.py @@ -4,74 +4,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Dict, final, List, Optional, Tuple +from typing import final, List -import torch +from executorch.backends.aoti.aoti_partitioner import AotiPartitioner from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip from executorch.exir._warnings import experimental from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.backend.partitioner import ( - DelegationSpec, - Partitioner, - PartitionResult, -) -from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer -from torch.export.exported_program import ExportedProgram @final @experimental( "This API and all of Metal backend related functionality are experimental." ) -class MetalPartitioner(Partitioner): +class MetalPartitioner(AotiPartitioner): """ - Metal partitioner for AOTInductor backend integration. - - This partitioner creates a single partition containing all operators from the input graph. - It skips core ATen decomposition, allowing the Metal backend to handle decomposition using - AOTInductor's MPS-specific decomposition table. - - Only operators that cannot be handled by the aoti-mps library will be excluded from - the partition and fall back to ExecuTorch's default or custom handling. + Metal partitioner driven by AOTInductor backend. """ def __init__(self, compile_spec: List[CompileSpec]) -> None: - self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec) - - def partition(self, exported_program: ExportedProgram) -> PartitionResult: - """ - Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. - """ - - partition_tags: Dict[str, DelegationSpec] = {} - tag = "tag0" - - for node in exported_program.graph.nodes: - if node.op != "call_function": - continue - node.meta["delegation_tag"] = tag - - partition_tags[tag] = self.delegation_spec - - tag_constant_data(exported_program) - tag_mutated_buffer(exported_program) - - return PartitionResult( - tagged_exported_program=exported_program, partition_tags=partition_tags - ) - - def ops_to_not_decompose( - self, ep: ExportedProgram - ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: - """ - Return a list of operations that should not be decomposed and let the AOT compiler handle them. 
-        Currently we skip ATen decompositon for all ops, and let the Metal backend handle them.
-        """
-        do_not_decompose = set()
-
-        for node in ep.graph.nodes:
-            if node.op == "call_function" and isinstance(
-                node.target, torch._ops.OpOverload
-            ):
-                do_not_decompose.add(node.target)
-        return list(do_not_decompose), None
+        super().__init__(MetalBackend.__name__, compile_spec)
diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS
index 22987a728ca..94af87bbaed 100644
--- a/backends/cuda/TARGETS
+++ b/backends/cuda/TARGETS
@@ -29,7 +29,6 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/exir/backend:partitioner",
-        "//executorch/exir/backend:utils",
+        "//executorch/backends/aoti:aoti_partitioner",
     ],
 )
diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py
index 64df7b7dcb2..e8f1276d5eb 100644
--- a/backends/cuda/cuda_partitioner.py
+++ b/backends/cuda/cuda_partitioner.py
@@ -4,74 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Dict, final, List, Optional, Tuple
+from typing import final, List
 
-import torch
+from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
 from executorch.backends.cuda.cuda_backend import CudaBackend  # usort: skip
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from executorch.exir.backend.partitioner import (
-    DelegationSpec,
-    Partitioner,
-    PartitionResult,
-)
-from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
-from torch.export.exported_program import ExportedProgram
 
 
 @final
 @experimental(
     "This API and all of cuda backend related functionality are experimental."
 )
-class CudaPartitioner(Partitioner):
+class CudaPartitioner(AotiPartitioner):
     """
-    CUDA partitioner for AOTInductor backend integration.
-
-    This partitioner creates a single partition containing all operators from the input graph.
-    It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
-    AOTInductor's CUDA-specific decomposition table.
-
-    Only operators that cannot be handled by the aoti-cuda library will be excluded from
-    the partition and fall back to ExecuTorch's default or custom handling.
+    CUDA partitioner driven by the AOTInductor backend.
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)
-
-    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
-        """
-        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
-        """
-
-        partition_tags: Dict[str, DelegationSpec] = {}
-        tag = "tag0"
-
-        for node in exported_program.graph.nodes:
-            if node.op != "call_function":
-                continue
-            node.meta["delegation_tag"] = tag
-
-        partition_tags[tag] = self.delegation_spec
-
-        tag_constant_data(exported_program)
-        tag_mutated_buffer(exported_program)
-
-        return PartitionResult(
-            tagged_exported_program=exported_program, partition_tags=partition_tags
-        )
-
-    def ops_to_not_decompose(
-        self, ep: ExportedProgram
-    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
-        """
-        Return a list of operations that should not be decomposed and let the AOT compiler handle them.
-        Currently we skip ATen decompositon for all ops, and let the cuda backend handle them.
- """ - do_not_decompose = set() - - for node in ep.graph.nodes: - if node.op == "call_function" and isinstance( - node.target, torch._ops.OpOverload - ): - do_not_decompose.add(node.target) - return list(do_not_decompose), None + super().__init__(CudaBackend.__name__, compile_spec)