diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS
new file mode 100644
index 00000000000..f54a95229c6
--- /dev/null
+++ b/backends/cuda/TARGETS
@@ -0,0 +1,18 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "cuda_partitioner",
+    srcs = [
+        "cuda_partitioner.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend:utils",
+    ],
+)
diff --git a/backends/cuda/__init__.py b/backends/cuda/__init__.py
new file mode 100644
index 00000000000..2e41cd717f6
--- /dev/null
+++ b/backends/cuda/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py
new file mode 100644
index 00000000000..cf22b0dea81
--- /dev/null
+++ b/backends/cuda/cuda_partitioner.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Dict, final, List, Optional, Tuple
+
+import torch
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.backend.utils import tag_constant_data
+from torch.export.exported_program import ExportedProgram
+
+
+@final
+class CudaPartitioner(Partitioner):
+    """
+    CUDA partitioner for AOTInductor backend integration.
+
+    This partitioner creates a single partition containing all operators from the input graph.
+    It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using
+    AOTInductor's CUDA-specific decomposition table.
+
+    Only operators that cannot be handled by the aoti-cuda library will be excluded from
+    the partition and fall back to ExecuTorch's default or custom handling.
+    """
+
+    def __init__(self, compile_spec: List[CompileSpec]) -> None:
+        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        """
+        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
+        """
+
+        partition_tags: Dict[str, DelegationSpec] = {}
+        for node in exported_program.graph.nodes:
+            if node.op != "call_function":
+                continue
+            tag = "tag0"
+            node.meta["delegation_tag"] = tag
+            partition_tags[tag] = self.delegation_spec
+
+        tag_constant_data(exported_program)
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Return the list of operations that should not be decomposed, so the AOT compiler can handle them.
+        Currently we skip ATen decomposition for all ops and let the CUDA backend handle them.
+ """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/backends/cuda/tests/TARGETS b/backends/cuda/tests/TARGETS new file mode 100644 index 00000000000..c775cf2fec2 --- /dev/null +++ b/backends/cuda/tests/TARGETS @@ -0,0 +1,20 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +oncall("executorch") + +python_unittest( + name = "test_cuda_partitioner", + srcs = [ + "test_cuda_partitioner.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/cuda:cuda_partitioner", + "//executorch/exir:lib", + "//executorch/exir/backend:compile_spec_schema", + ], +) diff --git a/backends/cuda/tests/__init__.py b/backends/cuda/tests/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/cuda/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/cuda/tests/test_cuda_partitioner.py b/backends/cuda/tests/test_cuda_partitioner.py new file mode 100644 index 00000000000..586d6f14494 --- /dev/null +++ b/backends/cuda/tests/test_cuda_partitioner.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestCudaPartitioner(unittest.TestCase): + """ + Test CUDA partitioner functionality. + + After CUDA partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the CUDA backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] 
+    ) -> PartitionResult:
+        """Helper method to get the partition result for a given module."""
+        # Export the model
+        exported_program = export(module, inputs, strict=True)
+
+        # Create partitioner and compile specs
+        compile_specs = [CompileSpec("cuda_compile_options", b"")]
+        partitioner = CudaPartitioner(compile_specs)
+
+        # Get partition result
+        partition_result = partitioner.partition(exported_program)
+
+        # Verify partition result structure
+        self.assertIsNotNone(partition_result)
+        self.assertTrue(hasattr(partition_result, "tagged_exported_program"))
+        self.assertTrue(hasattr(partition_result, "partition_tags"))
+
+        return partition_result
+
+    def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool:
+        """Check if the graph is fully partitioned (all operators have the same tag)."""
+        tagged_nodes = []
+        untagged_ops = []
+
+        for node in partition_result.tagged_exported_program.graph.nodes:
+            if node.op == "call_function":
+                if hasattr(node, "meta") and "delegation_tag" in node.meta:
+                    tagged_nodes.append(node)
+                else:
+                    untagged_ops.append(node)
+
+        # Check if we have any tagged nodes
+        if not tagged_nodes:
+            return False
+
+        # Check if all tagged nodes have the same tag
+        first_tag = tagged_nodes[0].meta["delegation_tag"]
+        all_same_tag = all(
+            node.meta.get("delegation_tag") == first_tag for node in tagged_nodes
+        )
+
+        # Should have no untagged operations for full partitioning
+        fully_partitioned = len(untagged_ops) == 0 and all_same_tag
+
+        return fully_partitioned
+
+    def test_simple_add_partition(self):
+        """
+        Test that the CUDA partitioner creates exactly one partition containing all operators.
+        Simple element-wise addition should result in a single graph with all ops tagged identically.
+        """
+
+        class AddModule(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return x + y
+
+        module = AddModule()
+        inputs = (torch.randn(3, 4), torch.randn(3, 4))
+
+        partition_result = self._get_partition_result(module, inputs)
+        fully_partitioned = self._check_fully_partitioned(partition_result)
+
+        self.assertTrue(
+            fully_partitioned,
+            "Graph should be fully partitioned with all operators having the same tag",
+        )
+
+    def test_conv2d_partition(self):
+        """
+        Test that the CUDA partitioner creates exactly one partition containing all operators.
+        Conv2D operation should result in a single graph with all ops tagged identically.
+        """
+
+        class Conv2dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.conv(x)
+
+        module = Conv2dModule()
+        inputs = (torch.randn(1, 3, 32, 32),)
+
+        partition_result = self._get_partition_result(module, inputs)
+        fully_partitioned = self._check_fully_partitioned(partition_result)
+
+        self.assertTrue(
+            fully_partitioned,
+            "Graph should be fully partitioned with all operators having the same tag",
+        )
+
+    def test_linear_partition(self):
+        """
+        Test that the CUDA partitioner creates exactly one partition containing all operators.
+        Linear layer operation should result in a single graph with all ops tagged identically.
+ """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + module = LinearModule() + inputs = (torch.randn(8, 128),) + + partition_result = self._get_partition_result(module, inputs) + fully_partitioned = self._check_fully_partitioned(partition_result) + + self.assertTrue( + fully_partitioned, + "Graph should be fully partitioned with all operators having the same tag", + )