Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/aoti/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def define_common_targets():
link_whole = True,
supports_python_dlopen = True,
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [
exported_deps = [
":common_shims",
":model_container",
],
Expand Down
22 changes: 22 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,25 @@ runtime.cxx_library(
("cuda", None, "cuda-lazy"),
],
)

runtime.cxx_library(
name = "cuda_backend",
srcs = [
"cuda_backend.cpp",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
# Constructor needed for backend registration.
compiler_flags = ["-Wno-global-constructors"],
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [
":runtime_shims",
"//executorch/backends/aoti:aoti_common",
"//executorch/runtime/backend:interface",
"//executorch/runtime/core/exec_aten/util:tensor_util",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
)
174 changes: 68 additions & 106 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ namespace executorch {
namespace backends {
namespace cuda {

#define LOAD_SYMBOL(name, handle) \
do { \
name = reinterpret_cast<name##Func>(dlsym(handle, #name)); \
ET_CHECK_OR_RETURN_ERROR( \
name != nullptr, AccessFailed, "Failed to load " #name); \
} while (0)

using namespace std;
using namespace aoti;

Expand All @@ -53,45 +60,11 @@ class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
Error register_shared_library_functions(void* so_handle) const {
AOTInductorModelContainerCreateWithDevice =
reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
if (AOTInductorModelContainerCreateWithDevice == nullptr) {
ET_LOG(Error, "Failed to load AOTInductorModelContainerCreateWithDevice");
return Error::AccessFailed;
}

AOTInductorModelContainerDelete =
reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
dlsym(so_handle, "AOTInductorModelContainerDelete"));
if (AOTInductorModelContainerDelete == nullptr) {
ET_LOG(Error, "Failed to load AOTInductorModelContainerDelete");
return Error::AccessFailed;
}

AOTInductorModelContainerGetNumInputs =
reinterpret_cast<AOTInductorModelContainerGetNumInputsFunc>(
dlsym(so_handle, "AOTInductorModelContainerGetNumInputs"));
if (AOTInductorModelContainerGetNumInputs == nullptr) {
ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumInputs");
return Error::AccessFailed;
}

AOTInductorModelContainerGetNumOutputs =
reinterpret_cast<AOTInductorModelContainerGetNumOutputsFunc>(
dlsym(so_handle, "AOTInductorModelContainerGetNumOutputs"));
if (AOTInductorModelContainerGetNumOutputs == nullptr) {
ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumOutputs");
return Error::AccessFailed;
}

AOTInductorModelContainerRun =
reinterpret_cast<AOTInductorModelContainerRunFunc>(
dlsym(so_handle, "AOTInductorModelContainerRun"));
if (AOTInductorModelContainerRun == nullptr) {
ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
return Error::AccessFailed;
}
LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);

return Error::Ok;
}
Expand Down Expand Up @@ -122,14 +95,13 @@ class ET_EXPERIMENTAL CudaBackend final

const NamedDataMap* named_data_map = context.get_named_data_map();
auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
if (!aoti_cuda_buffer.ok()) {
ET_LOG(
Error,
"Failed to get data for key %s: 0x%x",
so_blob_key.c_str(),
aoti_cuda_buffer.error());
return aoti_cuda_buffer.error();
}
ET_CHECK_OR_RETURN_ERROR(
aoti_cuda_buffer.ok(),
Internal,
"Failed to get data for key %s: 0x%x",
so_blob_key.c_str(),
static_cast<uint32_t>(aoti_cuda_buffer.error()));

// Generate dynamic temporary file path
filesystem::path temp_dir = filesystem::temp_directory_path();
filesystem::path so_path =
Expand All @@ -144,39 +116,35 @@ class ET_EXPERIMENTAL CudaBackend final
"Writing %zu bytes to %s",
aoti_cuda_buffer->size(),
so_path.c_str());

outfile.write(
static_cast<const char*>(aoti_cuda_buffer->data()),
aoti_cuda_buffer->size());

if (!outfile) {
ET_LOG(Error, "Failed to write to file %s", so_path.c_str());
return Error::AccessFailed;
}
ET_CHECK_OR_RETURN_ERROR(
outfile, AccessFailed, "Failed to write to file %s", so_path.c_str());

// Finish writing the file to disk
outfile.close();

// Load the ELF using dlopen
void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (so_handle == nullptr) {
ET_LOG(Error, "Failed to load shared library: %s", dlerror());
return Error::AccessFailed;
}
ET_CHECK_OR_RETURN_ERROR(
so_handle != nullptr,
AccessFailed,
"Failed to load shared library: %s",
dlerror());

processed->Free();

// Register all shared library functions
Error reg_err = register_shared_library_functions(so_handle);
if (reg_err != Error::Ok) {
return reg_err;
}
ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));

AOTInductorModelContainerHandle container_handle = nullptr;

AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice(
&container_handle, 1, "cuda", nullptr);
if (err != Error::Ok) {
return err;
}
ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
&container_handle, 1, "cuda", nullptr));

ET_LOG(Info, "container_handle = %p", container_handle);

AOTIDelegateHandle* handle = new AOTIDelegateHandle();
Expand Down Expand Up @@ -206,15 +174,13 @@ class ET_EXPERIMENTAL CudaBackend final
AOTInductorModelContainerGetNumOutputs(
handle->container_handle, &n_outputs);

if (n_inputs + n_outputs != args.size()) {
ET_LOG(
Error,
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
n_inputs,
n_outputs,
args.size());
return Error::InvalidArgument;
}
ET_CHECK_OR_RETURN_ERROR(
n_inputs + n_outputs == args.size(),
InvalidArgument,
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
n_inputs,
n_outputs,
args.size())

// NOTE: ExecuTorch tensors are always on CPU/host memory
// We need to create GPU copies for CUDA kernel execution
Expand Down Expand Up @@ -244,19 +210,20 @@ class ET_EXPERIMENTAL CudaBackend final
0, // device_index = 0
&gpu_input_handle);

if (create_err != Error::Ok) {
ET_LOG(Error, "Failed to create GPU tensor for input %d", i);
return Error::Internal;
}
ET_CHECK_OR_RETURN_ERROR(
create_err == Error::Ok,
Internal,
"Failed to create GPU tensor for input %d",
i);

gpu_inputs[i] = gpu_input_handle;

// Copy data from CPU to GPU
Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0);
if (copy_err != Error::Ok) {
ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i);
return Error::Internal;
}
ET_CHECK_OR_RETURN_ERROR(
aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
Internal,
"Failed to copy input %d from CPU to GPU",
i);
}
ET_LOG(Info, "Inputs copied to GPU");
// Process output tensors: create GPU counterparts for ExecuTorch CPU
Expand All @@ -280,10 +247,11 @@ class ET_EXPERIMENTAL CudaBackend final
0, // device_index = 0
&gpu_output_handle);

if (create_err != Error::Ok) {
ET_LOG(Error, "Failed to create GPU tensor for output %d", i);
return Error::Internal;
}
ET_CHECK_OR_RETURN_ERROR(
create_err == Error::Ok,
Internal,
"Failed to create GPU tensor for output %d",
i);

gpu_outputs[i] = gpu_output_handle;
}
Expand All @@ -298,13 +266,11 @@ class ET_EXPERIMENTAL CudaBackend final
handle->cuda_stream, // Pass the actual CUDA stream
nullptr); // proxy_executor_handle can remain nullptr

if (error != Error::Ok) {
ET_LOG(
Error,
"AOTInductorModelContainerRun failed with error code %d",
error);
return Error::Internal;
}
ET_CHECK_OR_RETURN_ERROR(
error == Error::Ok,
Internal,
"AOTInductorModelContainerRun failed with error code %d",
error);

// Copy GPU output results back to CPU output tensors
for (int i = 0; i < n_outputs; i++) {
Expand Down Expand Up @@ -356,12 +322,10 @@ class ET_EXPERIMENTAL CudaBackend final
if (handle->container_handle != nullptr) {
AOTIRuntimeError delete_result =
AOTInductorModelContainerDelete(handle->container_handle);
if (delete_result != Error::Ok) {
ET_LOG(
Error,
"AOTInductorModelContainerDelete failed with error code %d",
delete_result);
}
ET_CHECK_OR_LOG_ERROR(
delete_result == Error::Ok,
"Failed to delete AOTInductorModelContainer with error code %d",
delete_result);
handle->container_handle = nullptr;
}

Expand All @@ -374,13 +338,11 @@ class ET_EXPERIMENTAL CudaBackend final
if (!handle->so_path.empty()) {
std::error_code remove_error;
std::filesystem::remove(handle->so_path, remove_error);
if (remove_error) {
ET_LOG(
Error,
"Failed to remove temporary shared library %s: %s",
handle->so_path.c_str(),
remove_error.message().c_str());
}
ET_CHECK_OR_LOG_ERROR(
!remove_error,
"Failed to remove temporary shared library %s: %s",
handle->so_path.c_str(),
remove_error.message().c_str());
}

delete handle;
Expand Down
Loading