
Commit c212b25

Gasoonjia authored and facebook-github-bot committed
cuda export supported (#14478)
Summary: This diff introduces the CUDA backend, which compiles the partitioned model graph to run on CUDA devices. It uses the AOTInductor compiler to generate optimized, libtorch-free CUDA kernels for the model's operators. The compiled model can be executed on CUDA devices using the ExecuTorch runtime.

Reviewed By: larryliu0820

Differential Revision: D82987410
1 parent 4246468 commit c212b25
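
This commit view contains no end-to-end example, so the sketch below shows one way the new backend might be driven from Python. The tiny model, the input shapes, and the to_edge_transform_and_lower entry point are assumptions rather than part of this diff, and actually lowering requires a CUDA-capable machine with an AOTInductor-enabled PyTorch build.

# Hypothetical usage sketch (not part of this commit): lowering a tiny model
# through the new CUDA partitioner/backend pair.
import torch

from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower  # assumed lowering entry point


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyModel().eval()
example_inputs = (torch.randn(2, 16),)

# Export on CPU; CudaBackend.preprocess moves the graph to CUDA before AOTI compile.
exported = torch.export.export(model, example_inputs)

edge = to_edge_transform_and_lower(
    exported,
    partitioner=[CudaPartitioner([])],  # empty compile spec list for this sketch
)
executorch_program = edge.to_executorch()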

File tree

6 files changed, +425 -4 lines changed


backends/cuda/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/_serialize:lib",
+        "//executorch/exir/backend:backend_details",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 runtime.python_library(
     name = "cuda_partitioner",
     srcs = [

backends/cuda/cuda_backend.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@ (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import tempfile
import typing

from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


# Fallback operators that already exist in the et namespace.
supported_fallback_kernels: Dict[str, Any] = {}

# Fallback kernels that are required by the graph but not supported.
missing_fallback_kernels: Set[str] = set()


# Context manager that collects fallback kernels requested during AOTI compile,
# so the caller can raise if any unsupported kernels were generated.
@contextlib.contextmanager
def collect_unsupported_fallback_kernels():
    original_generate_c_shim_extern_kernel_call = (
        CppWrapperCpu.generate_c_shim_extern_kernel_call
    )

    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
        self,
        kernel: str,
        args: list[str],
        device: str,
        *,
        debug_args: Optional[list[str]] = None,
    ):
        if kernel not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel)

        original_generate_c_shim_extern_kernel_call(
            self, kernel, args, device, debug_args=debug_args
        )

    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
    )
    try:
        yield
    finally:
        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
            original_generate_c_shim_extern_kernel_call
        )


@final
class CudaBackend(BackendDetails):
    """
    CudaBackend compiles a model to run on CUDA devices. It uses the AOTInductor
    compiler to generate optimized, libtorch-free CUDA kernels for the model's
    operators. The compiled model can be executed on CUDA devices using the
    ExecuTorch runtime.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # Step 1: Move the edge_program from CPU to CUDA for AOTI compilation
        cuda_edge_program = move_to_device_pass(edge_program, "cuda")

        edge_program_module = cuda_edge_program.module()

        # Step 2: Grab all placeholders from the graph; the last n should be user inputs
        user_input_names = cuda_edge_program.graph_signature.user_inputs
        user_input_placeholders = []
        for node in cuda_edge_program.graph.nodes:
            if node.op == "placeholder" and node.name in user_input_names:
                user_input_placeholders.append(node.meta["val"])

        # Step 3: Create pseudo user inputs using torch.randn and the recorded input shapes
        faked_user_inputs = []
        for placeholder in user_input_placeholders:
            if isinstance(placeholder, torch.Tensor):
                # Generate a fake input with the same shape and dtype, on CUDA
                fake_input = torch.randn(
                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
                )
                faked_user_inputs.append(fake_input)

        faked_user_inputs = tuple(faked_user_inputs)

        # Create a temporary file path for the compiled shared library output
        output_path = tempfile.mktemp(suffix=".so", prefix="aoti_")

        options: dict[str, typing.Any] = {
            # Embed CUDA kernel binaries directly into the compiled shared object
            "aot_inductor.embed_kernel_binary": True,
            # Do not link against the full PyTorch/libtorch library
            "aot_inductor.link_libtorch": False,
            # Package model constants and other generated files directly in the shared object (.so) file
            "aot_inductor.package_constants_in_so": True,
            # Specify the output file path for the compiled shared object
            "aot_inductor.output_path": output_path,
            # Enable maximum automatic tuning for optimal performance
            "max_autotune": True,
            # Tune GEMM (General Matrix Multiply) operations with the TRITON backend only, to avoid operators from libtorch
            "max_autotune_gemm_backends": "TRITON",
            # Tune convolution operations with the TRITON backend only, to avoid operators from libtorch
            "max_autotune_conv_backends": "TRITON",
        }

        with collect_unsupported_fallback_kernels():
            _ = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
            if len(missing_fallback_kernels) > 0:
                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
                raise RuntimeError(
                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
                    "Please add them to the AOTI backend."
                )

        with open(output_path, "rb") as f:
            so_data = f.read()

        named_data_store = NamedDataStore()
        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")

        # Clean up the temporary output file
        os.remove(output_path)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )
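
The unsupported-kernel check above works by temporarily swapping CppWrapperCpu's extern-kernel emitter for a wrapper that records every requested kernel, then restoring the original in the finally block. Below is a minimal, self-contained sketch of that patch-and-restore pattern using a made-up Emitter class (not part of this commit), just to make the control flow concrete.

# Toy sketch of the patch-and-restore pattern; Emitter stands in for CppWrapperCpu.
import contextlib
from typing import Set


class Emitter:
    def emit_kernel_call(self, kernel: str) -> str:
        return f"call {kernel}"


seen_kernels: Set[str] = set()


@contextlib.contextmanager
def record_kernel_calls():
    original = Emitter.emit_kernel_call

    def recording_emit(self, kernel: str) -> str:
        seen_kernels.add(kernel)       # observe which kernel was requested
        return original(self, kernel)  # then defer to the real emitter

    Emitter.emit_kernel_call = recording_emit
    try:
        yield
    finally:
        Emitter.emit_kernel_call = original  # always restore, even on error


with record_kernel_calls():
    Emitter().emit_kernel_call("aten.nonzero")

assert seen_kernels == {"aten.nonzero"}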

backends/cuda/cuda_partitioner.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend  # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -31,7 +32,7 @@ class CudaPartitioner(Partitioner):
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)
+        self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """

backends/cuda/tests/TARGETS

Lines changed: 19 additions & 0 deletions
@@ -3,6 +3,24 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
 
 oncall("executorch")
 
+python_unittest(
+    name = "test_cuda_export",
+    srcs = [
+        "test_cuda_export.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cuda:cuda_backend",
+        "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:backend_api",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 python_unittest(
     name = "test_cuda_partitioner",
     srcs = [
@@ -14,6 +32,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/backends/cuda:cuda_backend",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
     ],
