
Commit 87e9c16

cuda export supported
Differential Revision: D82987410
Pull Request resolved: pytorch#14574
1 parent 79e9224 commit 87e9c16

6 files changed (+465, -4 lines)


backends/cuda/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/_serialize:lib",
+        "//executorch/exir/backend:backend_details",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 runtime.python_library(
     name = "cuda_partitioner",
     srcs = [

backends/cuda/cuda_backend.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@ (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import typing

from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


# Fallback operators that already exist in the ET namespace
supported_fallback_kernels: Dict[str, Any] = {}

# Fallback kernels that are required but not yet supported
missing_fallback_kernels: Set[str] = set()


# Context manager that guards against unsupported fallbacks:
# it records any fallback kernels generated during AOTI compile so the caller can raise.
@contextlib.contextmanager
def collect_unsupported_fallback_kernels():
    original_generate_c_shim_extern_kernel_call = (
        CppWrapperCpu.generate_c_shim_extern_kernel_call
    )
    original_generate_fallback_kernel_with_runtime_lookup_aot = (
        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
    )

    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
        self,
        kernel: str,
        args: list[str],
        device: str,
        *,
        debug_args: Optional[list[str]] = None,
    ):
        if kernel not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel)

        original_generate_c_shim_extern_kernel_call(
            self, kernel, args, device, debug_args=debug_args
        )

    def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
        self,
        op_overload,
        raw_args,
        output_args,
        raw_outputs,
    ):
        # Extract kernel name for collection
        kernel_name = getattr(op_overload, "_name", str(op_overload))
        if kernel_name not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel_name)

        original_generate_fallback_kernel_with_runtime_lookup_aot(
            self, op_overload, raw_args, output_args, raw_outputs
        )

    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
    )
    CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
        generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
    )
    try:
        yield
    finally:
        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
            original_generate_c_shim_extern_kernel_call
        )
        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
            original_generate_fallback_kernel_with_runtime_lookup_aot
        )


@final
@experimental(
    "This API and all of cuda backend related functionality are experimental."
)
class CudaBackend(BackendDetails):
    """
    CudaBackend compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate
    optimized, libtorch-free CUDA kernels for the model's operators. The compiled model can then be
    executed on CUDA devices using the ExecuTorch runtime.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # Move the edge_program from CPU to CUDA for AOTI compile
        cuda_edge_program = move_to_device_pass(edge_program, "cuda")

        edge_program_module = cuda_edge_program.module()

        # Grab all input placeholders from the graph
        user_input_names = cuda_edge_program.graph_signature.user_inputs
        user_input_placeholders = []
        for node in cuda_edge_program.graph.nodes:
            if node.op == "placeholder" and node.name in user_input_names:
                user_input_placeholders.append(node.meta["val"])

        # Create pseudo user inputs using torch.randn and metadata from input placeholders
        faked_user_inputs = []
        for placeholder in user_input_placeholders:
            if isinstance(placeholder, torch.Tensor):
                # Generate a fake input with the same shape and dtype, on CUDA
                fake_input = torch.randn(
                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
                )
                faked_user_inputs.append(fake_input)

        faked_user_inputs = tuple(faked_user_inputs)

        options: dict[str, typing.Any] = {
            # Embed CUDA kernel binaries directly into the compiled shared object
            "aot_inductor.embed_kernel_binary": True,
            # Do not link against the full PyTorch/libtorch library
            "aot_inductor.link_libtorch": False,
            # Package model constants and other generated files directly in the shared object (.so) file
            "aot_inductor.package_constants_in_so": True,
            # Enable maximum automatic tuning for optimal performance
            "max_autotune": True,
            # Use only the TRITON backend for GEMM (General Matrix Multiply) autotuning, to avoid operators from libtorch
            "max_autotune_gemm_backends": "TRITON",
            # Use only the TRITON backend for convolution autotuning, to avoid operators from libtorch
            "max_autotune_conv_backends": "TRITON",
        }

        with collect_unsupported_fallback_kernels():
            so_path = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
            if len(missing_fallback_kernels) > 0:
                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
                raise RuntimeError(
                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
                    "Please add them to the AOTI backend."
                )

        # pyre-ignorep[6]: Incompatible parameter type
        with open(so_path, "rb") as f:
            so_data = f.read()

        named_data_store = NamedDataStore()
        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")

        # Clean up the generated .so file; it has been packaged into the NamedDataStore
        # pyre-ignorep[6]: Incompatible parameter type
        os.remove(so_path)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )
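
For orientation, here is a minimal sketch of how this backend might be exercised end to end. It assumes the standard ExecuTorch lowering flow (torch.export.export plus executorch.exir.to_edge_transform_and_lower), which is not part of this commit, and the toy MulModel is purely illustrative.

# Illustrative sketch only -- not part of this commit.
# Assumes a CUDA-capable machine and the standard ExecuTorch export flow.
import torch

from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower


class MulModel(torch.nn.Module):  # hypothetical toy model for illustration
    def forward(self, x, y):
        return x * y


model = MulModel().cuda()
example_inputs = (
    torch.randn(4, 4, device="cuda"),
    torch.randn(4, 4, device="cuda"),
)

# Export to an ExportedProgram, then lower with the CUDA partitioner. The partitioner
# tags the graph for CudaBackend.preprocess, which AOT-compiles it with AOTInductor
# and stores the resulting shared object as the "aoti_cuda_blob" named data entry.
exported = torch.export.export(model, example_inputs)
edge = to_edge_transform_and_lower(exported, partitioner=[CudaPartitioner([])])
executorch_program = edge.to_executorch()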

backends/cuda/cuda_partitioner.py

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,8 @@
 from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend  # usort: skip
+from executorch.exir._warnings import experimental
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -18,6 +20,9 @@
 
 
 @final
+@experimental(
+    "This API and all of cuda backend related functionality are experimental."
+)
 class CudaPartitioner(Partitioner):
     """
     CUDA partitioner for AOTInductor backend integration.
@@ -31,7 +36,7 @@ class CudaPartitioner(Partitioner):
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)
+        self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """

backends/cuda/tests/TARGETS

Lines changed: 21 additions & 0 deletions
@@ -1,8 +1,28 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("@fbcode_macros//build_defs:python_unittest_remote_gpu.bzl", "python_unittest_remote_gpu")
 
 oncall("executorch")
 
+python_unittest_remote_gpu(
+    name = "test_cuda_export",
+    srcs = [
+        "test_cuda_export.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cuda:cuda_backend",
+        "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:backend_api",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+    keep_gpu_sections = True,
+)
+
 python_unittest(
     name = "test_cuda_partitioner",
     srcs = [
@@ -14,6 +34,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/backends/cuda:cuda_backend",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
     ],
