
Commit b33c284

[aoti-backend-consolidation 1/3] partitioners
Differential Revision: [D85700449](https://our.internmc.facebook.com/intern/diff/D85700449/) [ghstack-poisoned]
1 parent 37a65b5 commit b33c284

3 files changed: +92, -114 lines

backends/aoti/aoti_partitioner.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, List, Optional, Tuple

import torch
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram


@experimental(
    "This API and all of the AOTI backend related functionality are experimental."
)
class AotiPartitioner(Partitioner):
    """
    Base partitioner for AOTInductor-driven backend integration.

    This partitioner creates a single partition containing all operators from the input graph.
    It skips core ATen decomposition, allowing the backend to handle decomposition using
    AOTInductor's backend-specific decomposition table.

    Only operators that cannot be handled by the AOTI library will be excluded from
    the partition and fall back to ExecuTorch's default or custom handling.
    """

    def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None:
        """
        Initialize the AOTI partitioner.

        Args:
            backend_name: The name of the backend (e.g., "CudaBackend", "MetalBackend").
            compile_spec: List of compilation specifications.
        """
        self.delegation_spec = DelegationSpec(backend_name, compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        tag = "tag0"

        for node in exported_program.graph.nodes:
            if node.op != "call_function":
                continue
            node.meta["delegation_tag"] = tag

        partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)
        tag_mutated_buffer(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return the operations that should not be decomposed, leaving them to the AOT compiler.
        Currently we skip ATen decomposition for all ops and let the backend handle them.
        """
        do_not_decompose = set()

        for node in ep.graph.nodes:
            if node.op == "call_function" and isinstance(
                node.target, torch._ops.OpOverload
            ):
                do_not_decompose.add(node.target)
        return list(do_not_decompose), None
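
For context, a minimal sketch of exercising the new base class directly; the toy Add module, the placeholder "CudaBackend" string, and the empty compile-spec list are illustrative assumptions, not part of this commit:

import torch

from executorch.backends.aoti.aoti_partitioner import AotiPartitioner


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


ep = torch.export.export(Add(), (torch.randn(2), torch.randn(2)))

# "CudaBackend" and the empty compile-spec list are placeholders.
partitioner = AotiPartitioner("CudaBackend", [])
result = partitioner.partition(ep)

# Every call_function node carries the same delegation tag, so the whole
# graph forms a single partition for the named backend.
tags = {
    node.meta.get("delegation_tag")
    for node in result.tagged_exported_program.graph.nodes
    if node.op == "call_function"
}
print(tags)  # {"tag0"}

# ops_to_not_decompose reports every OpOverload seen in the graph, which is
# how core ATen decomposition gets skipped for the delegated ops.
ops, _ = partitioner.ops_to_not_decompose(ep)
print(ops)  # e.g. the aten.add.Tensor overload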

backends/apple/metal/metal_partitioner.py

Lines changed: 5 additions & 57 deletions
@@ -4,74 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Dict, final, List, Optional, Tuple
+from typing import final, List
 
-import torch
+from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
 from executorch.backends.apple.metal.metal_backend import MetalBackend  # usort: skip
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from executorch.exir.backend.partitioner import (
-    DelegationSpec,
-    Partitioner,
-    PartitionResult,
-)
-from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
-from torch.export.exported_program import ExportedProgram
 
 
 @final
 @experimental(
     "This API and all of Metal backend related functionality are experimental."
 )
-class MetalPartitioner(Partitioner):
+class MetalPartitioner(AotiPartitioner):
     """
-    Metal partitioner for AOTInductor backend integration.
-
-    This partitioner creates a single partition containing all operators from the input graph.
-    It skips core ATen decomposition, allowing the Metal backend to handle decomposition using
-    AOTInductor's MPS-specific decomposition table.
-
-    Only operators that cannot be handled by the aoti-mps library will be excluded from
-    the partition and fall back to ExecuTorch's default or custom handling.
+    Metal partitioner driven by the AOTInductor backend.
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec)
-
-    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
-        """
-        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
-        """
-
-        partition_tags: Dict[str, DelegationSpec] = {}
-        tag = "tag0"
-
-        for node in exported_program.graph.nodes:
-            if node.op != "call_function":
-                continue
-            node.meta["delegation_tag"] = tag
-
-        partition_tags[tag] = self.delegation_spec
-
-        tag_constant_data(exported_program)
-        tag_mutated_buffer(exported_program)
-
-        return PartitionResult(
-            tagged_exported_program=exported_program, partition_tags=partition_tags
-        )
-
-    def ops_to_not_decompose(
-        self, ep: ExportedProgram
-    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
-        """
-        Return a list of operations that should not be decomposed and let the AOT compiler handle them.
-        Currently we skip ATen decomposition for all ops, and let the Metal backend handle them.
-        """
-        do_not_decompose = set()
-
-        for node in ep.graph.nodes:
-            if node.op == "call_function" and isinstance(
-                node.target, torch._ops.OpOverload
-            ):
-                do_not_decompose.add(node.target)
-        return list(do_not_decompose), None
+        super().__init__(MetalBackend.__name__, compile_spec)

backends/cuda/cuda_partitioner.py

Lines changed: 5 additions & 57 deletions
@@ -4,74 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Dict, final, List, Optional, Tuple
+from typing import final, List
 
-import torch
+from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
 from executorch.backends.cuda.cuda_backend import CudaBackend  # usort: skip
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from executorch.exir.backend.partitioner import (
-    DelegationSpec,
-    Partitioner,
-    PartitionResult,
-)
-from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
-from torch.export.exported_program import ExportedProgram
 
 
 @final
 @experimental(
     "This API and all of cuda backend related functionality are experimental."
 )
-class CudaPartitioner(Partitioner):
+class CudaPartitioner(AotiPartitioner):
     """
-    CUDA partitioner for AOTInductor backend integration.
-
-    This partitioner creates a single partition containing all operators from the input graph.
-    It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
-    AOTInductor's CUDA-specific decomposition table.
-
-    Only operators that cannot be handled by the aoti-cuda library will be excluded from
-    the partition and fall back to ExecuTorch's default or custom handling.
+    CUDA partitioner driven by the AOTInductor backend.
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)
-
-    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
-        """
-        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
-        """
-
-        partition_tags: Dict[str, DelegationSpec] = {}
-        tag = "tag0"
-
-        for node in exported_program.graph.nodes:
-            if node.op != "call_function":
-                continue
-            node.meta["delegation_tag"] = tag
-
-        partition_tags[tag] = self.delegation_spec
-
-        tag_constant_data(exported_program)
-        tag_mutated_buffer(exported_program)
-
-        return PartitionResult(
-            tagged_exported_program=exported_program, partition_tags=partition_tags
-        )
-
-    def ops_to_not_decompose(
-        self, ep: ExportedProgram
-    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
-        """
-        Return a list of operations that should not be decomposed and let the AOT compiler handle them.
-        Currently we skip ATen decomposition for all ops, and let the cuda backend handle them.
-        """
-        do_not_decompose = set()
-
-        for node in ep.graph.nodes:
-            if node.op == "call_function" and isinstance(
-                node.target, torch._ops.OpOverload
-            ):
-                do_not_decompose.add(node.target)
-        return list(do_not_decompose), None
+        super().__init__(CudaBackend.__name__, compile_spec)
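
End-to-end lowering is unchanged by this consolidation. A minimal sketch, assuming a working ExecuTorch CUDA build; the toy model and the empty compile-spec list are illustrative assumptions:

import torch

from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower


class Model(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)


ep = torch.export.export(Model(), (torch.randn(4),))

# Because the partitioner returns all seen ops from ops_to_not_decompose,
# to_edge_transform_and_lower skips core ATen decomposition for them and
# hands the whole graph to the CUDA backend as a single delegate.
edge = to_edge_transform_and_lower(ep, partitioner=[CudaPartitioner([])])
print(edge.exported_program().graph_module.graph)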
