18 changes: 18 additions & 0 deletions backends/cuda/TARGETS
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
    name = "cuda_partitioner",
    srcs = [
        "cuda_partitioner.py",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend:utils",
    ],
)
5 changes: 5 additions & 0 deletions backends/cuda/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
69 changes: 69 additions & 0 deletions backends/cuda/cuda_partitioner.py
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from torch.export.exported_program import ExportedProgram


@final
class CudaPartitioner(Partitioner):
Review comment (Contributor): Mark as experimental
"""
CUDA partitioner for AOTInductor backend integration.

This partitioner creates a single partition containing all operators from the input graph.
It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
AOTInductor's CUDA-specific decomposition table.

Only operators that cannot be handled by the aoti-cuda library will be excluded from
the partition and fall back to ExecuTorch's default or custom handling.
"""

def __init__(self, compile_spec: List[CompileSpec]) -> None:
self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
"""

partition_tags: Dict[str, DelegationSpec] = {}
for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
tag = "tag0"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return a list of operations that should not be decomposed, letting the AOT compiler handle them.
        Currently we skip ATen decomposition for all ops and let the CUDA backend handle them.
        """
        do_not_decompose = set()

        for node in ep.graph.nodes:
            if node.op == "call_function" and isinstance(
                node.target, torch._ops.OpOverload
            ):
                do_not_decompose.add(node.target)
        return list(do_not_decompose), None

Review comment (JacobSzwejbka, Contributor, Sep 23, 2025): can aoti eat control flow hops?
Reply (Author): didn't test that. Let me add a test later.
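
A minimal usage sketch for reviewers (not part of this PR): it assumes the standard to_edge_transform_and_lower entry point, and that a CudaBackend delegate is registered (the backend itself lands separately); TinyModel is an illustrative placeholder.

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + 1.0)


ep = export(TinyModel(), (torch.randn(4),), strict=True)
partitioner = CudaPartitioner([CompileSpec("cuda_compile_options", b"")])

# to_edge_transform_and_lower consults ops_to_not_decompose() so the tagged
# ops skip core ATen decomposition, then lowers the single tagged partition
# to the CudaBackend delegate (which must be registered for this to succeed).
edge = to_edge_transform_and_lower(ep, partitioner=[partitioner])
et_program = edge.to_executorch()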
20 changes: 20 additions & 0 deletions backends/cuda/tests/TARGETS
@@ -0,0 +1,20 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("executorch")

python_unittest(
    name = "test_cuda_partitioner",
    srcs = [
        "test_cuda_partitioner.py",
    ],
    visibility = [
        "//executorch/...",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cuda:cuda_partitioner",
        "//executorch/exir:lib",
        "//executorch/exir/backend:compile_spec_schema",
    ],
)
5 changes: 5 additions & 0 deletions backends/cuda/tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
143 changes: 143 additions & 0 deletions backends/cuda/tests/test_cuda_partitioner.py
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import PartitionResult
from torch.export import export


class TestCudaPartitioner(unittest.TestCase):
    """
    Test CUDA partitioner functionality.

    After CUDA partitioning, there should be exactly one partitioned graph that contains
    all operators from the input graph. This means all operators should be tagged with
    the same delegation tag, indicating they will all be executed by the CUDA backend.
    """

    def _get_partition_result(
        self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
    ) -> PartitionResult:
        """Helper method to get partition result for a given module."""
        # Export the model
        exported_program = export(module, inputs, strict=True)

        # Create partitioner and compile specs
        compile_specs = [CompileSpec("cuda_compile_options", b"")]
        partitioner = CudaPartitioner(compile_specs)

        # Get partition result
        partition_result = partitioner.partition(exported_program)

        # Verify partition result structure
        self.assertIsNotNone(partition_result)
        self.assertTrue(hasattr(partition_result, "tagged_exported_program"))
        self.assertTrue(hasattr(partition_result, "partition_tags"))

        return partition_result

    def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool:
        """Check if the graph is fully partitioned (all operators have the same tag)."""
        tagged_nodes = []
        untagged_ops = []

        for node in partition_result.tagged_exported_program.graph.nodes:
            if node.op == "call_function":
                if hasattr(node, "meta") and "delegation_tag" in node.meta:
                    tagged_nodes.append(node)
                else:
                    untagged_ops.append(node)

        # Check if we have any tagged nodes
        if not tagged_nodes:
            return False

        # Check if all tagged nodes have the same tag
        first_tag = tagged_nodes[0].meta["delegation_tag"]
        all_same_tag = all(
            node.meta.get("delegation_tag") == first_tag for node in tagged_nodes
        )

        # Should have no untagged operations for full partitioning
        return len(untagged_ops) == 0 and all_same_tag

    def test_simple_add_partition(self):
        """
        Test that CUDA partitioner creates exactly one partition containing all operators.
        Simple element-wise addition should result in a single graph with all ops tagged identically.
        """

        class AddModule(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        module = AddModule()
        inputs = (torch.randn(3, 4), torch.randn(3, 4))

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )

    def test_conv2d_partition(self):
        """
        Test that CUDA partitioner creates exactly one partition containing all operators.
        Conv2D operation should result in a single graph with all ops tagged identically.
        """

        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.conv(x)

        module = Conv2dModule()
        inputs = (torch.randn(1, 3, 32, 32),)

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )

    def test_linear_partition(self):
        """
        Test that CUDA partitioner creates exactly one partition containing all operators.
        Linear layer operation should result in a single graph with all ops tagged identically.
        """

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 64)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

        module = LinearModule()
        inputs = (torch.randn(8, 128),)

        partition_result = self._get_partition_result(module, inputs)
        fully_partitioned = self._check_fully_partitioned(partition_result)

        self.assertTrue(
            fully_partitioned,
            "Graph should be fully partitioned with all operators having the same tag",
        )
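
Regarding the control-flow question in the review thread above (untested in this PR): a hypothetical follow-up test could look like the sketch below. Note that torch.cond exports as a call_function node whose target is a HigherOrderOperator, not an OpOverload, so partition() tags it while ops_to_not_decompose() filters it out; CondModule and the test name are illustrative only.

import unittest

import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export


class CondModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor) -> torch.Tensor:
            return x + 1.0

        def false_fn(x: torch.Tensor) -> torch.Tensor:
            return x - 1.0

        # Exports as torch.ops.higher_order.cond with two branch submodules.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


class TestCudaPartitionerControlFlow(unittest.TestCase):
    def test_cond_is_tagged(self):
        ep = export(CondModule(), (torch.randn(3, 4),), strict=True)
        partitioner = CudaPartitioner([CompileSpec("cuda_compile_options", b"")])
        result = partitioner.partition(ep)

        # partition() tags every call_function node, including the cond hop;
        # whether AOTInductor can actually compile it is the open question.
        tags = {
            node.meta.get("delegation_tag")
            for node in result.tagged_exported_program.graph.nodes
            if node.op == "call_function"
        }
        self.assertEqual(tags, {"tag0"})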