Skip to content

Commit 13bdf55

Browse files
committed
refactor cuda_backend.cpp
Pull Request resolved: #14904 This diff does a comprehensive refactor of cuda_backend.cpp. Two main points: 1. Reuse ExecuTorch standard macros (ET_CHECK_OR_RETURN_ERROR and others) to replace the existing if..else + ET_LOG branches 2. Introduce a LOAD_SYMBOL macro to consolidate the symbol-loading pipeline. ghstack-source-id: 314984328 @exported-using-ghexport Differential Revision: [D84135844](https://our.internmc.facebook.com/intern/diff/D84135844/)
1 parent 056ccb9 commit 13bdf55

File tree

3 files changed

+91
-107
lines changed

3 files changed

+91
-107
lines changed

backends/aoti/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_common_targets():
5151
link_whole = True,
5252
supports_python_dlopen = True,
5353
visibility = ["@EXECUTORCH_CLIENTS"],
54-
deps = [
54+
exported_deps = [
5555
":common_shims",
5656
":model_container",
5757
],

backends/cuda/runtime/TARGETS

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,25 @@ runtime.cxx_library(
3434
("cuda", None, "cuda-lazy"),
3535
],
3636
)
37+
38+
runtime.cxx_library(
39+
name = "cuda_backend",
40+
srcs = [
41+
"cuda_backend.cpp",
42+
],
43+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
44+
link_whole = True,
45+
supports_python_dlopen = True,
46+
# Constructor needed for backend registration.
47+
compiler_flags = ["-Wno-global-constructors"],
48+
visibility = ["@EXECUTORCH_CLIENTS"],
49+
deps = [
50+
":runtime_shims",
51+
"//executorch/backends/aoti:aoti_common",
52+
"//executorch/runtime/backend:interface",
53+
"//executorch/runtime/core/exec_aten/util:tensor_util",
54+
],
55+
external_deps = [
56+
("cuda", None, "cuda-lazy"),
57+
],
58+
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 68 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ namespace executorch {
3030
namespace backends {
3131
namespace cuda {
3232

33+
#define LOAD_SYMBOL(name, handle) \
34+
do { \
35+
name = reinterpret_cast<name##Func>(dlsym(handle, #name)); \
36+
ET_CHECK_OR_RETURN_ERROR( \
37+
name != nullptr, AccessFailed, "Failed to load " #name); \
38+
} while (0)
39+
3340
using namespace std;
3441
using namespace aoti;
3542

@@ -53,45 +60,11 @@ class ET_EXPERIMENTAL CudaBackend final
5360
: public ::executorch::runtime::BackendInterface {
5461
private:
5562
Error register_shared_library_functions(void* so_handle) const {
56-
AOTInductorModelContainerCreateWithDevice =
57-
reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
58-
dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
59-
if (AOTInductorModelContainerCreateWithDevice == nullptr) {
60-
ET_LOG(Error, "Failed to load AOTInductorModelContainerCreateWithDevice");
61-
return Error::AccessFailed;
62-
}
63-
64-
AOTInductorModelContainerDelete =
65-
reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
66-
dlsym(so_handle, "AOTInductorModelContainerDelete"));
67-
if (AOTInductorModelContainerDelete == nullptr) {
68-
ET_LOG(Error, "Failed to load AOTInductorModelContainerDelete");
69-
return Error::AccessFailed;
70-
}
71-
72-
AOTInductorModelContainerGetNumInputs =
73-
reinterpret_cast<AOTInductorModelContainerGetNumInputsFunc>(
74-
dlsym(so_handle, "AOTInductorModelContainerGetNumInputs"));
75-
if (AOTInductorModelContainerGetNumInputs == nullptr) {
76-
ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumInputs");
77-
return Error::AccessFailed;
78-
}
79-
80-
AOTInductorModelContainerGetNumOutputs =
81-
reinterpret_cast<AOTInductorModelContainerGetNumOutputsFunc>(
82-
dlsym(so_handle, "AOTInductorModelContainerGetNumOutputs"));
83-
if (AOTInductorModelContainerGetNumOutputs == nullptr) {
84-
ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumOutputs");
85-
return Error::AccessFailed;
86-
}
87-
88-
AOTInductorModelContainerRun =
89-
reinterpret_cast<AOTInductorModelContainerRunFunc>(
90-
dlsym(so_handle, "AOTInductorModelContainerRun"));
91-
if (AOTInductorModelContainerRun == nullptr) {
92-
ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
93-
return Error::AccessFailed;
94-
}
63+
LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
64+
LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
65+
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
66+
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
67+
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
9568

9669
return Error::Ok;
9770
}
@@ -122,14 +95,13 @@ class ET_EXPERIMENTAL CudaBackend final
12295

12396
const NamedDataMap* named_data_map = context.get_named_data_map();
12497
auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
125-
if (!aoti_cuda_buffer.ok()) {
126-
ET_LOG(
127-
Error,
128-
"Failed to get data for key %s: 0x%x",
129-
so_blob_key.c_str(),
130-
aoti_cuda_buffer.error());
131-
return aoti_cuda_buffer.error();
132-
}
98+
ET_CHECK_OR_RETURN_ERROR(
99+
aoti_cuda_buffer.ok(),
100+
Internal,
101+
"Failed to get data for key %s: 0x%x",
102+
so_blob_key.c_str(),
103+
static_cast<uint32_t>(aoti_cuda_buffer.error()));
104+
133105
// Generate dynamic temporary file path
134106
filesystem::path temp_dir = filesystem::temp_directory_path();
135107
filesystem::path so_path =
@@ -144,39 +116,35 @@ class ET_EXPERIMENTAL CudaBackend final
144116
"Writing %zu bytes to %s",
145117
aoti_cuda_buffer->size(),
146118
so_path.c_str());
119+
147120
outfile.write(
148121
static_cast<const char*>(aoti_cuda_buffer->data()),
149122
aoti_cuda_buffer->size());
150123

151-
if (!outfile) {
152-
ET_LOG(Error, "Failed to write to file %s", so_path.c_str());
153-
return Error::AccessFailed;
154-
}
124+
ET_CHECK_OR_RETURN_ERROR(
125+
outfile, AccessFailed, "Failed to write to file %s", so_path.c_str());
126+
155127
// Finish writing the file to disk
156128
outfile.close();
157129

158130
// Load the ELF using dlopen
159131
void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
160-
if (so_handle == nullptr) {
161-
ET_LOG(Error, "Failed to load shared library: %s", dlerror());
162-
return Error::AccessFailed;
163-
}
132+
ET_CHECK_OR_RETURN_ERROR(
133+
so_handle != nullptr,
134+
AccessFailed,
135+
"Failed to load shared library: %s",
136+
dlerror());
164137

165138
processed->Free();
166139

167140
// Register all shared library functions
168-
Error reg_err = register_shared_library_functions(so_handle);
169-
if (reg_err != Error::Ok) {
170-
return reg_err;
171-
}
141+
ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));
172142

173143
AOTInductorModelContainerHandle container_handle = nullptr;
174144

175-
AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice(
176-
&container_handle, 1, "cuda", nullptr);
177-
if (err != Error::Ok) {
178-
return err;
179-
}
145+
ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
146+
&container_handle, 1, "cuda", nullptr));
147+
180148
ET_LOG(Info, "container_handle = %p", container_handle);
181149

182150
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
@@ -206,15 +174,13 @@ class ET_EXPERIMENTAL CudaBackend final
206174
AOTInductorModelContainerGetNumOutputs(
207175
handle->container_handle, &n_outputs);
208176

209-
if (n_inputs + n_outputs != args.size()) {
210-
ET_LOG(
211-
Error,
212-
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
213-
n_inputs,
214-
n_outputs,
215-
args.size());
216-
return Error::InvalidArgument;
217-
}
177+
ET_CHECK_OR_RETURN_ERROR(
178+
n_inputs + n_outputs == args.size(),
179+
InvalidArgument,
180+
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
181+
n_inputs,
182+
n_outputs,
183+
args.size())
218184

219185
// NOTE: ExecuTorch tensors are always on CPU/host memory
220186
// We need to create GPU copies for CUDA kernel execution
@@ -244,19 +210,20 @@ class ET_EXPERIMENTAL CudaBackend final
244210
0, // device_index = 0
245211
&gpu_input_handle);
246212

247-
if (create_err != Error::Ok) {
248-
ET_LOG(Error, "Failed to create GPU tensor for input %d", i);
249-
return Error::Internal;
250-
}
213+
ET_CHECK_OR_RETURN_ERROR(
214+
create_err == Error::Ok,
215+
Internal,
216+
"Failed to create GPU tensor for input %d",
217+
i);
251218

252219
gpu_inputs[i] = gpu_input_handle;
253220

254221
// Copy data from CPU to GPU
255-
Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0);
256-
if (copy_err != Error::Ok) {
257-
ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i);
258-
return Error::Internal;
259-
}
222+
ET_CHECK_OR_RETURN_ERROR(
223+
aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
224+
Internal,
225+
"Failed to copy input %d from CPU to GPU",
226+
i);
260227
}
261228
ET_LOG(Info, "Inputs copied to GPU");
262229
// Process output tensors: create GPU counterparts for ExecuTorch CPU
@@ -280,10 +247,11 @@ class ET_EXPERIMENTAL CudaBackend final
280247
0, // device_index = 0
281248
&gpu_output_handle);
282249

283-
if (create_err != Error::Ok) {
284-
ET_LOG(Error, "Failed to create GPU tensor for output %d", i);
285-
return Error::Internal;
286-
}
250+
ET_CHECK_OR_RETURN_ERROR(
251+
create_err == Error::Ok,
252+
Internal,
253+
"Failed to create GPU tensor for output %d",
254+
i);
287255

288256
gpu_outputs[i] = gpu_output_handle;
289257
}
@@ -298,13 +266,11 @@ class ET_EXPERIMENTAL CudaBackend final
298266
handle->cuda_stream, // Pass the actual CUDA stream
299267
nullptr); // proxy_executor_handle can remain nullptr
300268

301-
if (error != Error::Ok) {
302-
ET_LOG(
303-
Error,
304-
"AOTInductorModelContainerRun failed with error code %d",
305-
error);
306-
return Error::Internal;
307-
}
269+
ET_CHECK_OR_RETURN_ERROR(
270+
error == Error::Ok,
271+
Internal,
272+
"AOTInductorModelContainerRun failed with error code %d",
273+
error);
308274

309275
// Copy GPU output results back to CPU output tensors
310276
for (int i = 0; i < n_outputs; i++) {
@@ -356,12 +322,10 @@ class ET_EXPERIMENTAL CudaBackend final
356322
if (handle->container_handle != nullptr) {
357323
AOTIRuntimeError delete_result =
358324
AOTInductorModelContainerDelete(handle->container_handle);
359-
if (delete_result != Error::Ok) {
360-
ET_LOG(
361-
Error,
362-
"AOTInductorModelContainerDelete failed with error code %d",
363-
delete_result);
364-
}
325+
ET_CHECK_OR_LOG_ERROR(
326+
delete_result == Error::Ok,
327+
"Failed to delete AOTInductorModelContainer with error code %d",
328+
delete_result);
365329
handle->container_handle = nullptr;
366330
}
367331

@@ -374,13 +338,11 @@ class ET_EXPERIMENTAL CudaBackend final
374338
if (!handle->so_path.empty()) {
375339
std::error_code remove_error;
376340
std::filesystem::remove(handle->so_path, remove_error);
377-
if (remove_error) {
378-
ET_LOG(
379-
Error,
380-
"Failed to remove temporary shared library %s: %s",
381-
handle->so_path.c_str(),
382-
remove_error.message().c_str());
383-
}
341+
ET_CHECK_OR_LOG_ERROR(
342+
!remove_error,
343+
"Failed to remove temporary shared library %s: %s",
344+
handle->so_path.c_str(),
345+
remove_error.message().c_str());
384346
}
385347

386348
delete handle;

0 commit comments

Comments
 (0)