Skip to content

Commit 7722509

Browse files
refactor cuda_backend.cpp (#14926)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #14904 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/49/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/49/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/47/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/49/orig Differential Revision: [D84135844](https://our.internmc.facebook.com/intern/diff/D84135844/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia <[email protected]>
1 parent 6ea0dc2 commit 7722509

File tree

16 files changed

+140
-196
lines changed

16 files changed

+140
-196
lines changed

backends/aoti/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -40,7 +40,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
4040
# Ensure symbols are exported properly
4141
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
4242

43-
# Link against PyTorch libraries and standard libraries
43+
# Link against ExecuTorch libraries and standard libraries
4444
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
4545
executorch_target_link_options_shared_lib(aoti_common)
4646

backends/aoti/common_shims.cpp

Lines changed: 8 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() {
127127
}
128128

129129
// Dtype constants - these return the PyTorch dtype codes
130-
// Currently only float32 is supported, but using robust enum-based approach
131130
int32_t aoti_torch_dtype_float32() {
132131
return 6; // PyTorch's float32 dtype code
133132
}
134133

134+
int32_t aoti_torch_dtype_bfloat16() {
135+
return 15; // PyTorch's bfloat16 dtype code
136+
}
137+
138+
int32_t aoti_torch_dtype_int64() {
139+
return 4; // PyTorch's int64 dtype code
140+
}
141+
135142
// Cleanup functions
136143
void cleanup_tensor_metadata() {
137144
internal::tensor_to_sizes.clear();

backends/aoti/common_shims.h

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
5858
int32_t aoti_torch_device_type_cpu();
5959
int32_t aoti_torch_layout_strided();
6060
int32_t aoti_torch_dtype_float32();
61+
int32_t aoti_torch_dtype_bfloat16();
62+
int32_t aoti_torch_dtype_int64();
6163

6264
// Autograd mode functions
6365
int32_t aoti_torch_grad_mode_is_enabled();

backends/aoti/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -51,7 +51,7 @@ def define_common_targets():
5151
link_whole = True,
5252
supports_python_dlopen = True,
5353
visibility = ["@EXECUTORCH_CLIENTS"],
54-
deps = [
54+
exported_deps = [
5555
":common_shims",
5656
":model_container",
5757
],

backends/cuda/CMakeLists.txt

Lines changed: 1 addition & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -55,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)
5555

5656
# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
5757
target_link_libraries(
58-
aoti_cuda
59-
PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
60-
# Link PyTorch libraries for AOTI CUDA functions
61-
${TORCH_LIBRARIES}
58+
aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
6259
)
6360
# If you need other CUDA libraries, link them similarly:
6461
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)

backends/cuda/runtime/TARGETS

Lines changed: 22 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -34,3 +34,25 @@ runtime.cxx_library(
3434
("cuda", None, "cuda-lazy"),
3535
],
3636
)
37+
38+
runtime.cxx_library(
39+
name = "cuda_backend",
40+
srcs = [
41+
"cuda_backend.cpp",
42+
],
43+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
44+
link_whole = True,
45+
supports_python_dlopen = True,
46+
# Constructor needed for backend registration.
47+
compiler_flags = ["-Wno-global-constructors"],
48+
visibility = ["@EXECUTORCH_CLIENTS"],
49+
deps = [
50+
":runtime_shims",
51+
"//executorch/backends/aoti:aoti_common",
52+
"//executorch/runtime/backend:interface",
53+
"//executorch/runtime/core/exec_aten/util:tensor_util",
54+
],
55+
external_deps = [
56+
("cuda", None, "cuda-lazy"),
57+
],
58+
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 76 additions & 131 deletions
Original file line number · Diff line number · Diff line change
@@ -26,9 +26,14 @@
2626
#include <executorch/backends/cuda/runtime/shims/memory.h>
2727
#include <executorch/backends/cuda/runtime/utils.h>
2828

29-
namespace executorch {
30-
namespace backends {
31-
namespace cuda {
29+
namespace executorch::backends::cuda {
30+
31+
#define LOAD_SYMBOL(name, handle) \
32+
do { \
33+
name = reinterpret_cast<name##Func>(dlsym(handle, #name)); \
34+
ET_CHECK_OR_RETURN_ERROR( \
35+
name != nullptr, AccessFailed, "Failed to load " #name); \
36+
} while (0)
3237

3338
using namespace std;
3439
using namespace aoti;
@@ -53,45 +58,11 @@ class ET_EXPERIMENTAL CudaBackend final
5358
: public ::executorch::runtime::BackendInterface {
5459
private:
5560
Error register_shared_library_functions(void* so_handle) const {
56-
AOTInductorModelContainerCreateWithDevice =
57-
reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
58-
dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
59-
if (AOTInductorModelContainerCreateWithDevice == nullptr) {
60-
ET_LOG(Error, "Failed to load AOTInductorModelContainerCreateWithDevice");
61-
return Error::AccessFailed;
62-
}
63-
64-
AOTInductorModelContainerDelete =
65-
reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
66-
dlsym(so_handle, "AOTInductorModelContainerDelete"));
67-
if (AOTInductorModelContainerDelete == nullptr) {
68-
ET_LOG(Error, "Failed to load AOTInductorModelContainerDelete");
69-
return Error::AccessFailed;
70-
}
71-
72-
AOTInductorModelContainerGetNumInputs =
73-
reinterpret_cast<AOTInductorModelContainerGetNumInputsFunc>(
74-
dlsym(so_handle, "AOTInductorModelContainerGetNumInputs"));
75-
if (AOTInductorModelContainerGetNumInputs == nullptr) {
76-
ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumInputs");
77-
return Error::AccessFailed;
78-
}
79-
80-
AOTInductorModelContainerGetNumOutputs =
81-
reinterpret_cast<AOTInductorModelContainerGetNumOutputsFunc>(
82-
dlsym(so_handle, "AOTInductorModelContainerGetNumOutputs"));
83-
if (AOTInductorModelContainerGetNumOutputs == nullptr) {
84-
ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumOutputs");
85-
return Error::AccessFailed;
86-
}
87-
88-
AOTInductorModelContainerRun =
89-
reinterpret_cast<AOTInductorModelContainerRunFunc>(
90-
dlsym(so_handle, "AOTInductorModelContainerRun"));
91-
if (AOTInductorModelContainerRun == nullptr) {
92-
ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
93-
return Error::AccessFailed;
94-
}
61+
LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
62+
LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
63+
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
64+
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
65+
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
9566

9667
return Error::Ok;
9768
}
@@ -122,14 +93,13 @@ class ET_EXPERIMENTAL CudaBackend final
12293

12394
const NamedDataMap* named_data_map = context.get_named_data_map();
12495
auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
125-
if (!aoti_cuda_buffer.ok()) {
126-
ET_LOG(
127-
Error,
128-
"Failed to get data for key %s: 0x%x",
129-
so_blob_key.c_str(),
130-
aoti_cuda_buffer.error());
131-
return aoti_cuda_buffer.error();
132-
}
96+
ET_CHECK_OR_RETURN_ERROR(
97+
aoti_cuda_buffer.ok(),
98+
Internal,
99+
"Failed to get data for key %s: 0x%x",
100+
so_blob_key.c_str(),
101+
static_cast<uint32_t>(aoti_cuda_buffer.error()));
102+
133103
// Generate dynamic temporary file path
134104
filesystem::path temp_dir = filesystem::temp_directory_path();
135105
filesystem::path so_path =
@@ -144,39 +114,35 @@ class ET_EXPERIMENTAL CudaBackend final
144114
"Writing %zu bytes to %s",
145115
aoti_cuda_buffer->size(),
146116
so_path.c_str());
117+
147118
outfile.write(
148119
static_cast<const char*>(aoti_cuda_buffer->data()),
149120
aoti_cuda_buffer->size());
150121

151-
if (!outfile) {
152-
ET_LOG(Error, "Failed to write to file %s", so_path.c_str());
153-
return Error::AccessFailed;
154-
}
122+
ET_CHECK_OR_RETURN_ERROR(
123+
outfile, AccessFailed, "Failed to write to file %s", so_path.c_str());
124+
155125
// Finish writing the file to disk
156126
outfile.close();
157127

158128
// Load the ELF using dlopen
159129
void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
160-
if (so_handle == nullptr) {
161-
ET_LOG(Error, "Failed to load shared library: %s", dlerror());
162-
return Error::AccessFailed;
163-
}
130+
ET_CHECK_OR_RETURN_ERROR(
131+
so_handle != nullptr,
132+
AccessFailed,
133+
"Failed to load shared library: %s",
134+
dlerror());
164135

165136
processed->Free();
166137

167138
// Register all shared library functions
168-
Error reg_err = register_shared_library_functions(so_handle);
169-
if (reg_err != Error::Ok) {
170-
return reg_err;
171-
}
139+
ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));
172140

173141
AOTInductorModelContainerHandle container_handle = nullptr;
174142

175-
AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice(
176-
&container_handle, 1, "cuda", nullptr);
177-
if (err != Error::Ok) {
178-
return err;
179-
}
143+
ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
144+
&container_handle, 1, "cuda", nullptr));
145+
180146
ET_LOG(Info, "container_handle = %p", container_handle);
181147

182148
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
@@ -206,15 +172,13 @@ class ET_EXPERIMENTAL CudaBackend final
206172
AOTInductorModelContainerGetNumOutputs(
207173
handle->container_handle, &n_outputs);
208174

209-
if (n_inputs + n_outputs != args.size()) {
210-
ET_LOG(
211-
Error,
212-
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
213-
n_inputs,
214-
n_outputs,
215-
args.size());
216-
return Error::InvalidArgument;
217-
}
175+
ET_CHECK_OR_RETURN_ERROR(
176+
n_inputs + n_outputs == args.size(),
177+
InvalidArgument,
178+
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
179+
n_inputs,
180+
n_outputs,
181+
args.size())
218182

219183
// NOTE: ExecuTorch tensors are always on CPU/host memory
220184
// We need to create GPU copies for CUDA kernel execution
@@ -244,19 +208,20 @@ class ET_EXPERIMENTAL CudaBackend final
244208
0, // device_index = 0
245209
&gpu_input_handle);
246210

247-
if (create_err != Error::Ok) {
248-
ET_LOG(Error, "Failed to create GPU tensor for input %d", i);
249-
return Error::Internal;
250-
}
211+
ET_CHECK_OR_RETURN_ERROR(
212+
create_err == Error::Ok,
213+
Internal,
214+
"Failed to create GPU tensor for input %d",
215+
i);
251216

252217
gpu_inputs[i] = gpu_input_handle;
253218

254219
// Copy data from CPU to GPU
255-
Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0);
256-
if (copy_err != Error::Ok) {
257-
ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i);
258-
return Error::Internal;
259-
}
220+
ET_CHECK_OR_RETURN_ERROR(
221+
aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
222+
Internal,
223+
"Failed to copy input %d from CPU to GPU",
224+
i);
260225
}
261226
ET_LOG(Info, "Inputs copied to GPU");
262227
// Process output tensors: create GPU counterparts for ExecuTorch CPU
@@ -280,10 +245,11 @@ class ET_EXPERIMENTAL CudaBackend final
280245
0, // device_index = 0
281246
&gpu_output_handle);
282247

283-
if (create_err != Error::Ok) {
284-
ET_LOG(Error, "Failed to create GPU tensor for output %d", i);
285-
return Error::Internal;
286-
}
248+
ET_CHECK_OR_RETURN_ERROR(
249+
create_err == Error::Ok,
250+
Internal,
251+
"Failed to create GPU tensor for output %d",
252+
i);
287253

288254
gpu_outputs[i] = gpu_output_handle;
289255
}
@@ -298,13 +264,11 @@ class ET_EXPERIMENTAL CudaBackend final
298264
handle->cuda_stream, // Pass the actual CUDA stream
299265
nullptr); // proxy_executor_handle can remain nullptr
300266

301-
if (error != Error::Ok) {
302-
ET_LOG(
303-
Error,
304-
"AOTInductorModelContainerRun failed with error code %d",
305-
error);
306-
return Error::Internal;
307-
}
267+
ET_CHECK_OR_RETURN_ERROR(
268+
error == Error::Ok,
269+
Internal,
270+
"AOTInductorModelContainerRun failed with error code %d",
271+
error);
308272

309273
// Copy GPU output results back to CPU output tensors
310274
for (int i = 0; i < n_outputs; i++) {
@@ -320,18 +284,6 @@ class ET_EXPERIMENTAL CudaBackend final
320284
i);
321285
}
322286

323-
// Clean up GPU tensors that we created (ExecuTorch tensors are always
324-
// CPU, so all GPU tensors are our copies)
325-
for (int i = 0; i < n_inputs; i++) {
326-
// All GPU input tensors were created by us, delete them
327-
aoti_torch_delete_tensor_object(gpu_inputs[i]);
328-
}
329-
330-
for (int i = 0; i < n_outputs; i++) {
331-
// All GPU output tensors were created by us, delete them
332-
aoti_torch_delete_tensor_object(gpu_outputs[i]);
333-
}
334-
335287
return Error::Ok;
336288
}
337289

@@ -352,18 +304,13 @@ class ET_EXPERIMENTAL CudaBackend final
352304
handle->cuda_stream = nullptr;
353305
}
354306

355-
// Delete the container BEFORE closing the shared library
356-
if (handle->container_handle != nullptr) {
357-
AOTIRuntimeError delete_result =
358-
AOTInductorModelContainerDelete(handle->container_handle);
359-
if (delete_result != Error::Ok) {
360-
ET_LOG(
361-
Error,
362-
"AOTInductorModelContainerDelete failed with error code %d",
363-
delete_result);
364-
}
365-
handle->container_handle = nullptr;
366-
}
307+
// NOTE: AOTInductorModelContainerDelete does not work correctly with
308+
// multiple .so files. Deleting one container frees shared resources,
309+
// which causes segmentation faults when attempting to delete other
310+
// containers. As a workaround, we skip explicit container deletion
311+
// and defer cleanup to the OS.
312+
// TODO(gasoonjia): Find a proper solution for safe container deletion.
313+
// AOTInductorModelContainerDelete(handle->container_handle);
367314

368315
// Now close the shared library
369316
if (handle->so_handle != nullptr) {
@@ -374,27 +321,25 @@ class ET_EXPERIMENTAL CudaBackend final
374321
if (!handle->so_path.empty()) {
375322
std::error_code remove_error;
376323
std::filesystem::remove(handle->so_path, remove_error);
377-
if (remove_error) {
378-
ET_LOG(
379-
Error,
380-
"Failed to remove temporary shared library %s: %s",
381-
handle->so_path.c_str(),
382-
remove_error.message().c_str());
383-
}
324+
ET_CHECK_OR_LOG_ERROR(
325+
!remove_error,
326+
"Failed to remove temporary shared library %s: %s",
327+
handle->so_path.c_str(),
328+
remove_error.message().c_str());
384329
}
385330

386331
delete handle;
332+
clear_all_tensors();
387333
}
388334
};
389335

390-
} // namespace cuda
336+
} // namespace executorch::backends::cuda
391337

338+
namespace executorch::backends {
392339
namespace {
393340
auto cls = cuda::CudaBackend();
394341
executorch::runtime::Backend backend{"CudaBackend", &cls};
395342
static executorch::runtime::Error success_with_compiler =
396343
register_backend(backend);
397344
} // namespace
398-
399-
} // namespace backends
400-
} // namespace executorch
345+
} // namespace executorch::backends

0 commit comments

Comments (0)