Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries
target_link_libraries(
aoti_common
PUBLIC extension_tensor ${CMAKE_DL_LIBS}
# Link PyTorch libraries for AOTI functions
${TORCH_LIBRARIES}
)
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(aoti_common)

install(
Expand Down
2 changes: 2 additions & 0 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
void* so_handle;
std::string so_path;
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
};

} // namespace aoti
Expand Down
6 changes: 4 additions & 2 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# CUDA-specific AOTI functionality
set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
set(_aoti_cuda_sources
runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
runtime/shims/cuda_guard.cpp
)
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ runtime.cxx_library(
name = "runtime_shims",
srcs = [
"guard.cpp",
"shims/cuda_guard.cpp",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/cuda_guard.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
Expand Down
22 changes: 20 additions & 2 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <dlfcn.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
Expand All @@ -16,14 +17,14 @@

#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Include our shim layer headers
#include <executorch/backends/aoti/aoti_model_container.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;

// Create a CUDA stream for asynchronous execution
cudaStream_t cuda_stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
handle->cuda_stream = static_cast<void*>(cuda_stream);

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand Down Expand Up @@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
n_inputs,
gpu_outputs.data(), // Use GPU output tensors
n_outputs,
nullptr, // Pass the actual CUDA stream!
handle->cuda_stream, // Pass the actual CUDA stream
nullptr); // proxy_executor_handle can remain nullptr

if (error != Error::Ok) {
Expand Down Expand Up @@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
}
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Destroy the CUDA stream if it exists
if (handle->cuda_stream != nullptr) {
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
ET_CHECK_OR_LOG_ERROR(
stream_err == cudaSuccess,
"Failed to destroy CUDA stream: %s",
cudaGetErrorString(stream_err));
handle->cuda_stream = nullptr;
}

// Delete the container BEFORE closing the shared library
if (handle->container_handle != nullptr) {
AOTIRuntimeError delete_result =
Expand Down
109 changes: 109 additions & 0 deletions backends/cuda/runtime/shims/cuda_guard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>

namespace executorch {
namespace backends {
namespace cuda {

extern "C" {

AOTITorchError aoti_torch_create_cuda_guard(
int32_t device_index,
CUDAGuardHandle* ret_guard) {
ET_CHECK_OR_RETURN_ERROR(
ret_guard != nullptr,
InvalidArgument,
"aoti_torch_create_cuda_guard failed: ret_guard is null");

auto result = CUDAGuard::create(device_index);
if (!result.ok()) {
return result.error();
}
*ret_guard = new CUDAGuard(std::move(result.get()));
return Error::Ok;
}

AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
ET_CHECK_OR_RETURN_ERROR(
guard != nullptr,
InvalidArgument,
"aoti_torch_delete_cuda_guard failed: guard is null");

delete guard;
return Error::Ok;
}

AOTITorchError aoti_torch_cuda_guard_set_index(
CUDAGuardHandle guard,
int32_t device_index) {
ET_CHECK_OR_RETURN_ERROR(
guard != nullptr,
InvalidArgument,
"aoti_torch_cuda_guard_set_index failed: guard is null");

ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index));
return Error::Ok;
}

AOTITorchError aoti_torch_create_cuda_stream_guard(
void* stream,
int32_t device_index,
CUDAStreamGuardHandle* ret_guard) {
ET_CHECK_OR_RETURN_ERROR(
ret_guard != nullptr,
InvalidArgument,
"aoti_torch_create_cuda_stream_guard failed: ret_guard is null");

ET_CHECK_OR_RETURN_ERROR(
stream != nullptr,
InvalidArgument,
"aoti_torch_create_cuda_stream_guard failed: stream is null");

auto result =
CUDAStreamGuard::create(static_cast<cudaStream_t>(stream), device_index);
if (!result.ok()) {
return result.error();
}
*ret_guard = new CUDAStreamGuard(std::move(result.get()));
return Error::Ok;
}

AOTITorchError aoti_torch_delete_cuda_stream_guard(
CUDAStreamGuardHandle guard) {
ET_CHECK_OR_RETURN_ERROR(
guard != nullptr,
InvalidArgument,
"aoti_torch_delete_cuda_stream_guard failed: guard is null");

delete guard;
return Error::Ok;
}

AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream) {
ET_CHECK_OR_RETURN_ERROR(
ret_stream != nullptr,
InvalidArgument,
"aoti_torch_get_current_cuda_stream failed: ret_stream is null");

auto result = getCurrentCUDAStream(device_index);
if (!result.ok()) {
return result.error();
}
*ret_stream = static_cast<void*>(result.get());
return Error::Ok;
}

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
104 changes: 104 additions & 0 deletions backends/cuda/runtime/shims/cuda_guard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/guard.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {

using executorch::backends::aoti::AOTITorchError;

extern "C" {

// Handle types for CUDA guards
using CUDAGuardHandle = CUDAGuard*;
using CUDAStreamGuardHandle = CUDAStreamGuard*;

/**
* Creates a CUDA device guard that sets the current device and restores it
* upon destruction.
*
* @param device_index The device index to set as current
* @param ret_guard Output parameter for the created guard handle (must not be
* null)
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_create_cuda_guard(
int32_t device_index,
CUDAGuardHandle* ret_guard);

/**
* Deletes a CUDA device guard and frees its associated resources.
*
* @param guard Handle to the guard to be deleted
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);

/**
* Sets the CUDA device to a new index for an existing guard.
*
* @param guard Handle to the guard
* @param device_index The device index to set as current
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_cuda_guard_set_index(
CUDAGuardHandle guard,
int32_t device_index);

/**
* Creates a CUDA stream guard that sets the current device and stream,
* restoring both upon destruction.
*
* @param stream The CUDA stream to set as current
* @param device_index The device index for the stream
* @param ret_guard Output parameter for the created guard handle (must not be
* null)
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_create_cuda_stream_guard(
void* stream,
int32_t device_index,
CUDAStreamGuardHandle* ret_guard);

/**
* Deletes a CUDA stream guard and frees its associated resources.
*
* @param guard Handle to the stream guard to be deleted
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);

/**
* Gets the current CUDA stream for a specified device.
*
* @param device_index The device index (-1 to use current device)
* @param ret_stream Output parameter for the current stream (must not be null)
* @return AOTITorchError error code (Error::Ok on success, or an error code on
* failure)
*/
AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream);

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_cpp_unittest("aoti_torch_copy_")
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
Loading
Loading