Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries
target_link_libraries(
aoti_common
PUBLIC extension_tensor ${CMAKE_DL_LIBS}
# Link PyTorch libraries for AOTI functions
${TORCH_LIBRARIES}
)
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(aoti_common)

install(
Expand Down
2 changes: 2 additions & 0 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
void* so_handle;
std::string so_path;
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
};

} // namespace aoti
Expand Down
6 changes: 4 additions & 2 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# CUDA-specific AOTI functionality
set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
set(_aoti_cuda_sources
runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
runtime/shims/cuda_guard.cpp
)
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ runtime.cxx_library(
name = "runtime_shims",
srcs = [
"guard.cpp",
"shims/cuda_guard.cpp",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/cuda_guard.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
Expand Down
22 changes: 20 additions & 2 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <dlfcn.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
Expand All @@ -16,14 +17,14 @@

#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Include our shim layer headers
#include <executorch/backends/aoti/aoti_model_container.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;

// Create a CUDA stream for asynchronous execution
cudaStream_t cuda_stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
handle->cuda_stream = static_cast<void*>(cuda_stream);

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand Down Expand Up @@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
n_inputs,
gpu_outputs.data(), // Use GPU output tensors
n_outputs,
nullptr, // Pass the actual CUDA stream!
handle->cuda_stream, // Pass the actual CUDA stream
nullptr); // proxy_executor_handle can remain nullptr

if (error != Error::Ok) {
Expand Down Expand Up @@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
}
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Destroy the CUDA stream if it exists
if (handle->cuda_stream != nullptr) {
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
ET_CHECK_OR_LOG_ERROR(
stream_err == cudaSuccess,
"Failed to destroy CUDA stream: %s",
cudaGetErrorString(stream_err));
handle->cuda_stream = nullptr;
}

// Delete the container BEFORE closing the shared library
if (handle->container_handle != nullptr) {
AOTIRuntimeError delete_result =
Expand Down
109 changes: 109 additions & 0 deletions backends/cuda/runtime/shims/cuda_guard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>

#include <new>
#include <utility>

namespace executorch {
namespace backends {
namespace cuda {

extern "C" {

/**
 * Creates a heap-allocated CUDAGuard for `device_index` and hands ownership of
 * it to the caller through `ret_guard`.
 *
 * The returned handle is released by passing it to
 * aoti_torch_delete_cuda_guard().
 *
 * @param device_index Device index forwarded to CUDAGuard::create().
 * @param ret_guard Receives the new guard handle; must not be null. It is
 *     written only on success.
 * @return Error::Ok on success, Error::InvalidArgument if `ret_guard` is null,
 *     the error reported by CUDAGuard::create() if guard creation fails, or
 *     Error::MemoryAllocationFailed if the handle cannot be allocated.
 */
AOTITorchError aoti_torch_create_cuda_guard(
    int32_t device_index,
    CUDAGuardHandle* ret_guard) {
  ET_CHECK_OR_RETURN_ERROR(
      ret_guard != nullptr,
      InvalidArgument,
      "aoti_torch_create_cuda_guard failed: ret_guard is null");

  auto result = CUDAGuard::create(device_index);
  if (!result.ok()) {
    return result.error();
  }

  // This function has C linkage, so an allocation failure has to surface as an
  // error code instead of escaping as std::bad_alloc.
  CUDAGuardHandle guard =
      new (std::nothrow) CUDAGuard(std::move(result.get()));
  ET_CHECK_OR_RETURN_ERROR(
      guard != nullptr,
      MemoryAllocationFailed,
      "aoti_torch_create_cuda_guard failed: unable to allocate guard");

  *ret_guard = guard;
  return Error::Ok;
}

/**
 * Destroys a CUDAGuard through its handle using `delete`.
 *
 * @param guard Handle to destroy; must not be null. Expected to be a handle
 *     obtained from aoti_torch_create_cuda_guard().
 * @return Error::Ok once the guard is deleted, or Error::InvalidArgument if
 *     `guard` is null.
 */
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
  // A null handle is reported to the caller rather than silently accepted.
  const bool has_guard = guard != nullptr;
  ET_CHECK_OR_RETURN_ERROR(
      has_guard,
      InvalidArgument,
      "aoti_torch_delete_cuda_guard failed: guard is null");

  delete guard;

  return Error::Ok;
}

/**
 * Forwards `device_index` to CUDAGuard::set_index() on an existing guard.
 *
 * @param guard Handle to the guard to update; must not be null.
 * @param device_index Device index passed to CUDAGuard::set_index().
 * @return Error::Ok on success, Error::InvalidArgument if `guard` is null, or
 *     the error returned by CUDAGuard::set_index().
 */
AOTITorchError aoti_torch_cuda_guard_set_index(
    CUDAGuardHandle guard,
    int32_t device_index) {
  const bool has_guard = guard != nullptr;
  ET_CHECK_OR_RETURN_ERROR(
      has_guard,
      InvalidArgument,
      "aoti_torch_cuda_guard_set_index failed: guard is null");

  // Any non-Ok status from the guard is propagated unchanged.
  const auto status = guard->set_index(device_index);
  ET_CHECK_OK_OR_RETURN_ERROR(status);

  return Error::Ok;
}

/**
 * Creates a heap-allocated CUDAStreamGuard for `stream` on `device_index` and
 * hands ownership of it to the caller through `ret_guard`.
 *
 * The returned handle is released by passing it to
 * aoti_torch_delete_cuda_stream_guard().
 *
 * @param stream CUDA stream passed as `void*`; it is cast to `cudaStream_t`
 *     before use and must not be null.
 * @param device_index Device index forwarded to CUDAStreamGuard::create().
 * @param ret_guard Receives the new guard handle; must not be null. It is
 *     written only on success.
 * @return Error::Ok on success, Error::InvalidArgument if `ret_guard` or
 *     `stream` is null, the error reported by CUDAStreamGuard::create() if
 *     guard creation fails, or Error::MemoryAllocationFailed if the handle
 *     cannot be allocated.
 */
AOTITorchError aoti_torch_create_cuda_stream_guard(
    void* stream,
    int32_t device_index,
    CUDAStreamGuardHandle* ret_guard) {
  ET_CHECK_OR_RETURN_ERROR(
      ret_guard != nullptr,
      InvalidArgument,
      "aoti_torch_create_cuda_stream_guard failed: ret_guard is null");

  ET_CHECK_OR_RETURN_ERROR(
      stream != nullptr,
      InvalidArgument,
      "aoti_torch_create_cuda_stream_guard failed: stream is null");

  auto result =
      CUDAStreamGuard::create(static_cast<cudaStream_t>(stream), device_index);
  if (!result.ok()) {
    return result.error();
  }

  // This function has C linkage, so an allocation failure has to surface as an
  // error code instead of escaping as std::bad_alloc.
  CUDAStreamGuardHandle guard =
      new (std::nothrow) CUDAStreamGuard(std::move(result.get()));
  ET_CHECK_OR_RETURN_ERROR(
      guard != nullptr,
      MemoryAllocationFailed,
      "aoti_torch_create_cuda_stream_guard failed: unable to allocate guard");

  *ret_guard = guard;
  return Error::Ok;
}

/**
 * Destroys a CUDAStreamGuard through its handle using `delete`.
 *
 * @param guard Handle to destroy; must not be null. Expected to be a handle
 *     obtained from aoti_torch_create_cuda_stream_guard().
 * @return Error::Ok once the guard is deleted, or Error::InvalidArgument if
 *     `guard` is null.
 */
AOTITorchError aoti_torch_delete_cuda_stream_guard(
    CUDAStreamGuardHandle guard) {
  // A null handle is reported to the caller rather than silently accepted.
  const bool has_guard = guard != nullptr;
  ET_CHECK_OR_RETURN_ERROR(
      has_guard,
      InvalidArgument,
      "aoti_torch_delete_cuda_stream_guard failed: guard is null");

  delete guard;

  return Error::Ok;
}

/**
 * Looks up the current CUDA stream via getCurrentCUDAStream() and returns it
 * to the caller as a `void*`.
 *
 * @param device_index Device index forwarded to getCurrentCUDAStream().
 * @param ret_stream Receives the stream; must not be null. It is written only
 *     on success.
 * @return Error::Ok on success, Error::InvalidArgument if `ret_stream` is
 *     null, or the error reported by getCurrentCUDAStream().
 */
AOTITorchError aoti_torch_get_current_cuda_stream(
    int32_t device_index,
    void** ret_stream) {
  const bool has_output = ret_stream != nullptr;
  ET_CHECK_OR_RETURN_ERROR(
      has_output,
      InvalidArgument,
      "aoti_torch_get_current_cuda_stream failed: ret_stream is null");

  auto stream_or_error = getCurrentCUDAStream(device_index);
  if (stream_or_error.ok()) {
    // The stream crosses the C boundary as an untyped pointer.
    *ret_stream = static_cast<void*>(stream_or_error.get());
    return Error::Ok;
  }

  return stream_or_error.error();
}

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
104 changes: 104 additions & 0 deletions backends/cuda/runtime/shims/cuda_guard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/guard.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {

using executorch::backends::aoti::AOTITorchError;

extern "C" {

// Handle types for CUDA guards. Each handle is a raw pointer to the
// corresponding guard object; the create/delete shim functions declared below
// produce and consume these handles.
using CUDAGuardHandle = CUDAGuard*;
using CUDAStreamGuardHandle = CUDAStreamGuard*;

/**
* Creates a CUDA device guard that sets the current device and restores it
* upon destruction.
*
* The guard is heap-allocated; release it with aoti_torch_delete_cuda_guard().
*
* @param device_index The device index to set as current
* @param ret_guard Output parameter for the created guard handle (must not be
* null); written only on success
* @return AOTITorchError error code: Error::Ok on success,
* Error::InvalidArgument if ret_guard is null, or the error reported by
* CUDAGuard::create() on failure
*/
AOTITorchError aoti_torch_create_cuda_guard(
int32_t device_index,
CUDAGuardHandle* ret_guard);

/**
* Deletes a CUDA device guard and frees its associated resources.
*
* @param guard Handle to the guard to be deleted (must not be null)
* @return AOTITorchError error code: Error::Ok on success, or
* Error::InvalidArgument if guard is null
*/
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);

/**
* Sets the CUDA device to a new index for an existing guard.
*
* @param guard Handle to the guard (must not be null)
* @param device_index The device index to set as current
* @return AOTITorchError error code: Error::Ok on success,
* Error::InvalidArgument if guard is null, or the error returned by
* CUDAGuard::set_index() on failure
*/
AOTITorchError aoti_torch_cuda_guard_set_index(
CUDAGuardHandle guard,
int32_t device_index);

/**
* Creates a CUDA stream guard that sets the current device and stream,
* restoring both upon destruction.
*
* The guard is heap-allocated; release it with
* aoti_torch_delete_cuda_stream_guard().
*
* @param stream The CUDA stream to set as current, passed as void* and cast to
* cudaStream_t by the implementation (must not be null)
* @param device_index The device index for the stream
* @param ret_guard Output parameter for the created guard handle (must not be
* null); written only on success
* @return AOTITorchError error code: Error::Ok on success,
* Error::InvalidArgument if ret_guard or stream is null, or the error reported
* by CUDAStreamGuard::create() on failure
*/
AOTITorchError aoti_torch_create_cuda_stream_guard(
void* stream,
int32_t device_index,
CUDAStreamGuardHandle* ret_guard);

/**
* Deletes a CUDA stream guard and frees its associated resources.
*
* @param guard Handle to the stream guard to be deleted (must not be null)
* @return AOTITorchError error code: Error::Ok on success, or
* Error::InvalidArgument if guard is null
*/
AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);

/**
* Gets the current CUDA stream for a specified device.
*
* @param device_index The device index (-1 to use current device)
* @param ret_stream Output parameter for the current stream, returned as void*
* (must not be null); written only on success
* @return AOTITorchError error code: Error::Ok on success,
* Error::InvalidArgument if ret_stream is null, or the error reported by
* getCurrentCUDAStream() on failure
*/
AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream);

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_cpp_unittest("aoti_torch_copy_")
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
Loading
Loading