
Commit e690b0a

Gasoonjia authored and facebook-github-bot committed
cuda export supported
Summary: This diff introduces the CUDA backend, which compiles the partitioned model graph to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators without depending on libtorch. The compiled model can be executed on CUDA devices using the ExecuTorch runtime.

Differential Revision: D82987410
1 parent 152afbe commit e690b0a
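
For context, the sketch below shows how a model could be lowered through this backend end to end. It is illustrative only: the module path executorch.backends.cuda.cuda_partitioner and a CudaPartitioner constructor that takes a list of CompileSpecs are assumptions, and the rest follows the standard ExecuTorch export flow (torch.export.export, to_edge, to_backend, to_executorch).

# Illustrative sketch, not part of this commit. Assumes CudaPartitioner is importable
# from executorch.backends.cuda.cuda_partitioner and accepts a list of CompileSpecs.
import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner  # assumed path
from executorch.exir import to_edge


class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


example_inputs = (torch.randn(2, 16),)

# Standard ExecuTorch flow: export the model, convert it to an edge program, then
# delegate the partitioned subgraphs to the CUDA backend. The to_backend step is
# what invokes CudaBackend.preprocess (and AOTInductor) under the hood.
exported = torch.export.export(SmallModel(), example_inputs)
edge = to_edge(exported)
edge = edge.to_backend(CudaPartitioner([]))  # assumed: partitioner takes compile specs
executorch_program = edge.to_executorch()

# Serialize the program; the compiled .so produced by preprocess travels as named data.
with open("small_model_cuda.pte", "wb") as f:
    f.write(executorch_program.buffer)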

File tree

4 files changed: +408 -0 lines changed

backends/cuda/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/_serialize:lib",
+        "//executorch/exir/backend:backend_details",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 runtime.python_library(
     name = "cuda_partitioner",
     srcs = [
backends/cuda/cuda_backend.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import tempfile
import typing

from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


# Fallback operators that already exist in the et namespace
supported_fallback_kernels: Dict[str, Any] = {}

# Fallback kernels that are required but not yet supported
missing_fallback_kernels: Set[str] = set()


# Context manager that collects any unsupported fallback kernels generated during
# AOTI compile; preprocess raises if any were collected.
@contextlib.contextmanager
def collect_unsupported_fallback_kernels():
    original_generate_c_shim_extern_kernel_call = (
        CppWrapperCpu.generate_c_shim_extern_kernel_call
    )

    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
        self,
        kernel: str,
        args: list[str],
        device: str,
        *,
        debug_args: Optional[list[str]] = None,
    ):
        if kernel not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel)

        original_generate_c_shim_extern_kernel_call(
            self, kernel, args, device, debug_args=debug_args
        )

    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
    )
    try:
        yield
    finally:
        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
            original_generate_c_shim_extern_kernel_call
        )


@final
class CudaBackend(BackendDetails):
    """
    CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the
    AOTInductor compiler to generate optimized CUDA kernels for the model's operators
    without depending on libtorch. The compiled model can be executed on CUDA devices
    using the ExecuTorch runtime.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:

        named_data_store = NamedDataStore()

        # Move the edge_program from CPU to CUDA for AOTI compile
        cuda_edge_program = move_to_device_pass(edge_program, "cuda")

        edge_program_module = cuda_edge_program.module()
        args, kwargs = cuda_edge_program.example_inputs

        # Create a temporary file path for the compiled shared library output
        output_path = tempfile.mktemp(suffix=".so", prefix="aoti_")

        options: dict[str, typing.Any] = {
            # Embed CUDA kernel binaries directly into the compiled shared object
            "aot_inductor.embed_kernel_binary": True,
            # Do not link against the full PyTorch/libtorch library
            "aot_inductor.link_libtorch": False,
            # Package model constants and other generated files directly in the shared object (.so) file
            "aot_inductor.package_constants_in_so": True,
            # Specify the output file path for the compiled shared object
            "aot_inductor.output_path": output_path,
            # Enable maximum automatic tuning for optimal performance
            "max_autotune": True,
            # Restrict GEMM (General Matrix Multiply) autotuning to the TRITON backend to avoid operators from libtorch
            "max_autotune_gemm_backends": "TRITON",
            # Restrict convolution autotuning to the TRITON backend to avoid operators from libtorch
            "max_autotune_conv_backends": "TRITON",
        }

        with collect_unsupported_fallback_kernels():
            _ = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options)  # type: ignore[arg-type]
            if len(missing_fallback_kernels) > 0:
                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
                raise RuntimeError(
                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
                    "Please add them to the AOTI backend."
                )

        with open(output_path, "rb") as f:
            so_data = f.read()

        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")

        # Clean up the temporary output file
        os.remove(output_path)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )
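
The collect_unsupported_fallback_kernels helper works by temporarily replacing CppWrapperCpu.generate_c_shim_extern_kernel_call with a wrapper that records each extern kernel before delegating to the original method, and restoring the original in a finally block. Below is a self-contained sketch of that same pattern on a toy class; all names here are invented for illustration.

# Generic illustration (not part of the commit): a context manager that temporarily
# monkey-patches a class method to collect information as a side effect.
import contextlib
from typing import Set


class Emitter:
    def emit(self, kernel: str) -> str:
        return f"call {kernel}"


seen_kernels: Set[str] = set()


@contextlib.contextmanager
def record_emitted_kernels():
    original_emit = Emitter.emit  # keep a reference to the real method

    def emit_and_record(self, kernel: str) -> str:
        seen_kernels.add(kernel)            # collect the kernel name as a side effect
        return original_emit(self, kernel)  # then defer to the original behavior

    Emitter.emit = emit_and_record  # install the wrapper for the duration of the block
    try:
        yield
    finally:
        Emitter.emit = original_emit  # always restore the original method


with record_emitted_kernels():
    Emitter().emit("aten.addmm.default")

print(seen_kernels)  # {'aten.addmm.default'}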

backends/cuda/tests/TARGETS

Lines changed: 19 additions & 0 deletions
@@ -3,6 +3,24 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
 
 oncall("executorch")
 
+python_unittest(
+    name = "test_cuda_export",
+    srcs = [
+        "test_cuda_export.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cuda:cuda_backend",
+        "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:backend_api",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 python_unittest(
     name = "test_cuda_partitioner",
     srcs = [
@@ -14,6 +32,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/backends/cuda:cuda_backend",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
     ],
