
Commit 44f3740

CUDA partitioner supported
Differential Revision: D82987193
Pull Request resolved: #14477

File tree

6 files changed: +260 -0 lines changed


backends/cuda/TARGETS

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
    name = "cuda_partitioner",
    srcs = [
        "cuda_partitioner.py",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend:utils",
    ],
)

backends/cuda/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

backends/cuda/cuda_partitioner.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from torch.export.exported_program import ExportedProgram


@final
class CudaPartitioner(Partitioner):
    """
    CUDA partitioner for AOTInductor backend integration.

    This partitioner creates a single partition containing all operators from the input graph.
    It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
    AOTInductor's CUDA-specific decomposition table.

    Only operators that cannot be handled by the aoti-cuda library will be excluded from
    the partition and fall back to ExecuTorch's default or custom handling.
    """

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
        """

        partition_tags: Dict[str, DelegationSpec] = {}
        for node in exported_program.graph.nodes:
            if node.op != "call_function":
                continue
            tag = "tag0"
            node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return a list of operations that should not be decomposed, leaving them for the
        AOT compiler to handle. Currently we skip ATen decomposition for all ops and let
        the CUDA backend handle them.
        """
        do_not_decompose = set()

        for node in ep.graph.nodes:
            if node.op == "call_function" and isinstance(
                node.target, torch._ops.OpOverload
            ):
                do_not_decompose.add(node.target)
        return list(do_not_decompose), None
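
For context, here is a minimal usage sketch of how this partitioner could be handed to the standard ExecuTorch lowering flow. It assumes a "CudaBackend" delegate is registered (the delegate itself is not part of this diff) and uses the to_edge_transform_and_lower entry point from executorch.exir:

# Minimal usage sketch, not part of this commit. Assumes a "CudaBackend"
# delegate is registered with ExecuTorch; only the partitioner is new here.
import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export


class AddModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


exported = export(AddModule(), (torch.randn(3, 4), torch.randn(3, 4)), strict=True)
partitioner = CudaPartitioner([CompileSpec("cuda_compile_options", b"")])

# to_edge_transform_and_lower consults ops_to_not_decompose() before running
# ATen decompositions, so the whole graph reaches the backend undecomposed.
edge = to_edge_transform_and_lower(exported, partitioner=[partitioner])
executorch_program = edge.to_executorch()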

backends/cuda/tests/TARGETS

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("executorch")

python_unittest(
    name = "test_cuda_partitioner",
    srcs = [
        "test_cuda_partitioner.py",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cuda:cuda_partitioner",
        "//executorch/exir:lib",
        "//executorch/exir/backend:compile_spec_schema",
    ],
)

backends/cuda/tests/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
backends/cuda/tests/test_cuda_partitioner.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import PartitionResult
from torch.export import export


class TestCudaPartitioner(unittest.TestCase):
    """
    Test CUDA partitioner functionality.

    After CUDA partitioning, there should be exactly one partitioned graph that contains
    all operators from the input graph. This means all operators should be tagged with
    the same delegation tag, indicating they will all be executed by the CUDA backend.
    """

    def _get_partition_result(
        self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
    ) -> PartitionResult:
        """Helper method to get the partition result for a given module."""
        # Export the model
        exported_program = export(module, inputs, strict=True)

        # Create partitioner and compile specs
        compile_specs = [CompileSpec("cuda_compile_options", b"")]
        partitioner = CudaPartitioner(compile_specs)

        # Get partition result
        partition_result = partitioner.partition(exported_program)

        # Verify partition result structure
        self.assertIsNotNone(partition_result)
        self.assertTrue(hasattr(partition_result, "tagged_exported_program"))
        self.assertTrue(hasattr(partition_result, "partition_tags"))

        return partition_result

    def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool:
        """Check if the graph is fully partitioned (all operators have the same tag)."""
        tagged_nodes = []
        untagged_ops = []

        for node in partition_result.tagged_exported_program.graph.nodes:
            if node.op == "call_function":
                if hasattr(node, "meta") and "delegation_tag" in node.meta:
                    tagged_nodes.append(node)
                else:
                    untagged_ops.append(node)

        # Check if we have any tagged nodes
        if not tagged_nodes:
            return False

        # Check if all tagged nodes have the same tag
        first_tag = tagged_nodes[0].meta["delegation_tag"]
        all_same_tag = all(
            node.meta.get("delegation_tag") == first_tag for node in tagged_nodes
        )

        # Should have no untagged operations for full partitioning
        fully_partitioned = len(untagged_ops) == 0 and all_same_tag

        return fully_partitioned

    def test_simple_add_partition(self):
        """
        Test that the CUDA partitioner creates exactly one partition containing all operators.
        Simple element-wise addition should result in a single graph with all ops tagged identically.
        """

        class AddModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        module = AddModule()
        inputs = (torch.randn(3, 4), torch.randn(3, 4))

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )

    def test_conv2d_partition(self):
        """
        Test that the CUDA partitioner creates exactly one partition containing all operators.
        A Conv2d operation should result in a single graph with all ops tagged identically.
        """

        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.conv(x)

        module = Conv2dModule()
        inputs = (torch.randn(1, 3, 32, 32),)

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )

    def test_linear_partition(self):
        """
        Test that the CUDA partitioner creates exactly one partition containing all operators.
        A linear layer should result in a single graph with all ops tagged identically.
        """

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 64)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

        module = LinearModule()
        inputs = (torch.randn(8, 128),)

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )
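Besides the python_unittest Buck target above, the new tests can also be exercised directly with the stdlib unittest runner; a minimal sketch, assuming executorch and this commit are importable in the current Python environment:

# Run the new test case with the stdlib unittest runner; assumes executorch
# is installed (e.g. from source) so the test module can be imported.
import unittest

from executorch.backends.cuda.tests.test_cuda_partitioner import TestCudaPartitioner

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestCudaPartitioner)
unittest.TextTestRunner(verbosity=2).run(suite)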