diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt index 8d49bcf1f96..ce364f2c4b0 100644 --- a/backends/aoti/CMakeLists.txt +++ b/backends/aoti/CMakeLists.txt @@ -41,12 +41,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC) target_link_options(aoti_common PUBLIC -Wl,--export-dynamic) # Link against PyTorch libraries and standard libraries -target_link_libraries( - aoti_common - PUBLIC extension_tensor ${CMAKE_DL_LIBS} - # Link PyTorch libraries for AOTI functions - ${TORCH_LIBRARIES} -) +target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS}) executorch_target_link_options_shared_lib(aoti_common) install( diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 844bd2d5a77..9b185327172 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -77,6 +77,8 @@ struct AOTIDelegateHandle { void* so_handle; std::string so_path; AOTInductorModelContainerHandle container_handle; + void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header + // dependency }; } // namespace aoti diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index acbb7adc87f..dc5b1b786f8 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) find_package_torch() # CUDA-specific AOTI functionality -set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp - runtime/shims/tensor_attribute.cpp runtime/guard.cpp +set(_aoti_cuda_sources + runtime/cuda_backend.cpp runtime/shims/memory.cpp + runtime/shims/tensor_attribute.cpp runtime/guard.cpp + runtime/shims/cuda_guard.cpp ) add_library(aoti_cuda STATIC ${_aoti_cuda_sources}) target_include_directories( diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index c4b778eccc5..0386b5a008d 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -6,11 +6,13 @@ runtime.cxx_library( name = "runtime_shims", srcs = [ "guard.cpp", + "shims/cuda_guard.cpp", "shims/memory.cpp", "shims/tensor_attribute.cpp", ], headers = [ "guard.h", + "shims/cuda_guard.h", "shims/memory.h", "shims/tensor_attribute.h", "utils.h", diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 08031ce6a26..5f113b1ce68 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -16,7 +17,6 @@ #include #include -#include #include #include @@ -24,6 +24,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final handle->so_handle = so_handle; handle->so_path = so_path.string(); handle->container_handle = container_handle; + + // Create a CUDA stream for asynchronous execution + cudaStream_t cuda_stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream)); + handle->cuda_stream = static_cast(cuda_stream); + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final n_inputs, gpu_outputs.data(), // Use GPU output tensors n_outputs, - nullptr, // Pass the actual CUDA stream! + handle->cuda_stream, // Pass the actual CUDA stream nullptr); // proxy_executor_handle can remain nullptr if (error != Error::Ok) { @@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final } AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + // Destroy the CUDA stream if it exists + if (handle->cuda_stream != nullptr) { + cudaStream_t cuda_stream = static_cast(handle->cuda_stream); + cudaError_t stream_err = cudaStreamDestroy(cuda_stream); + ET_CHECK_OR_LOG_ERROR( + stream_err == cudaSuccess, + "Failed to destroy CUDA stream: %s", + cudaGetErrorString(stream_err)); + handle->cuda_stream = nullptr; + } + // Delete the container BEFORE closing the shared library if (handle->container_handle != nullptr) { AOTIRuntimeError delete_result = diff --git a/backends/cuda/runtime/shims/cuda_guard.cpp b/backends/cuda/runtime/shims/cuda_guard.cpp new file mode 100644 index 00000000000..5740d0bf654 --- /dev/null +++ b/backends/cuda/runtime/shims/cuda_guard.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace cuda { + +extern "C" { + +AOTITorchError aoti_torch_create_cuda_guard( + int32_t device_index, + CUDAGuardHandle* ret_guard) { + ET_CHECK_OR_RETURN_ERROR( + ret_guard != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_guard failed: ret_guard is null"); + + auto result = CUDAGuard::create(device_index); + if (!result.ok()) { + return result.error(); + } + *ret_guard = new CUDAGuard(std::move(result.get())); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_delete_cuda_guard failed: guard is null"); + + delete guard; + return Error::Ok; +} + +AOTITorchError aoti_torch_cuda_guard_set_index( + CUDAGuardHandle guard, + int32_t device_index) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_cuda_guard_set_index failed: guard is null"); + + ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index)); + return Error::Ok; +} + +AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard) { + ET_CHECK_OR_RETURN_ERROR( + ret_guard != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_stream_guard failed: ret_guard is null"); + + ET_CHECK_OR_RETURN_ERROR( + stream != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_stream_guard failed: stream is null"); + + auto result = + CUDAStreamGuard::create(static_cast(stream), device_index); + if (!result.ok()) { + return result.error(); + } + *ret_guard = new CUDAStreamGuard(std::move(result.get())); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_cuda_stream_guard( + CUDAStreamGuardHandle guard) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_delete_cuda_stream_guard failed: guard is null"); + + delete guard; + return Error::Ok; +} + +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, + void** ret_stream) { + ET_CHECK_OR_RETURN_ERROR( + ret_stream != nullptr, + InvalidArgument, + "aoti_torch_get_current_cuda_stream failed: ret_stream is null"); + + auto result = getCurrentCUDAStream(device_index); + if (!result.ok()) { + return result.error(); + } + *ret_stream = static_cast(result.get()); + return Error::Ok; +} + +} // extern "C" + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h new file mode 100644 index 00000000000..6da869064a7 --- /dev/null +++ b/backends/cuda/runtime/shims/cuda_guard.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +using executorch::backends::aoti::AOTITorchError; + +extern "C" { + +// Handle types for CUDA guards +using CUDAGuardHandle = CUDAGuard*; +using CUDAStreamGuardHandle = CUDAStreamGuard*; + +/** + * Creates a CUDA device guard that sets the current device and restores it + * upon destruction. + * + * @param device_index The device index to set as current + * @param ret_guard Output parameter for the created guard handle (must not be + * null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_create_cuda_guard( + int32_t device_index, + CUDAGuardHandle* ret_guard); + +/** + * Deletes a CUDA device guard and frees its associated resources. + * + * @param guard Handle to the guard to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); + +/** + * Sets the CUDA device to a new index for an existing guard. + * + * @param guard Handle to the guard + * @param device_index The device index to set as current + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_cuda_guard_set_index( + CUDAGuardHandle guard, + int32_t device_index); + +/** + * Creates a CUDA stream guard that sets the current device and stream, + * restoring both upon destruction. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @param ret_guard Output parameter for the created guard handle (must not be + * null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard); + +/** + * Deletes a CUDA stream guard and frees its associated resources. + * + * @param guard Handle to the stream guard to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); + +/** + * Gets the current CUDA stream for a specified device. + * + * @param device_index The device index (-1 to use current device) + * @param ret_stream Output parameter for the current stream (must not be null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, + void** ret_stream); + +} // extern "C" + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index fcb95a0beb7..70f27b86bec 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -32,3 +32,4 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2") cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor") cuda_shim_cpp_unittest("aoti_torch_copy_") + cuda_shim_cpp_unittest("aoti_torch_cuda_guard") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp new file mode 100644 index 00000000000..7527965cdb8 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_guard.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. Will be added in the future. +class AOTITorchCUDAGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + + void TearDown() override { + if (cudaGetDeviceCount(&original_device_) == cudaSuccess) { + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + } + + int original_device_ = 0; +}; + +TEST_F(AOTITorchCUDAGuardTest, CreateAndDeleteCUDAGuard) { + CUDAGuardHandle guard = nullptr; + AOTITorchError error = aoti_torch_create_cuda_guard(0, &guard); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(guard, nullptr); + + int current_device = -1; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + EXPECT_EQ(current_device, 0); + + error = aoti_torch_delete_cuda_guard(guard); + EXPECT_EQ(error, Error::Ok); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateCUDAGuardNullReturnPointer) { + AOTITorchError error = aoti_torch_create_cuda_guard(0, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, DeleteCUDAGuardNullHandle) { + AOTITorchError error = aoti_torch_delete_cuda_guard(nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, CUDAGuardSetIndexNullHandle) { + AOTITorchError error = aoti_torch_cuda_guard_set_index(nullptr, 0); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, CUDAGuardSetIndexInvalidDevice) { + CUDAGuardHandle guard = nullptr; + AOTITorchError error = aoti_torch_create_cuda_guard(0, &guard); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(guard, nullptr); + + error = aoti_torch_cuda_guard_set_index(guard, 999); + EXPECT_NE(error, Error::Ok); + + error = aoti_torch_delete_cuda_guard(guard); + EXPECT_EQ(error, Error::Ok); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateAndDeleteCUDAStreamGuard) { + cudaStream_t stream; + ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess); + + CUDAStreamGuardHandle guard = nullptr; + AOTITorchError error = aoti_torch_create_cuda_stream_guard(stream, 0, &guard); + + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(guard, nullptr); + + error = aoti_torch_delete_cuda_stream_guard(guard); + EXPECT_EQ(error, Error::Ok); + + ASSERT_EQ(cudaStreamDestroy(stream), cudaSuccess); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateCUDAStreamGuardNullReturnPointer) { + cudaStream_t stream; + ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess); + + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(stream, 0, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); + + ASSERT_EQ(cudaStreamDestroy(stream), cudaSuccess); +} + +TEST_F(AOTITorchCUDAGuardTest, CreateCUDAStreamGuardNullStream) { + CUDAStreamGuardHandle guard = nullptr; + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(nullptr, 0, &guard); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, DeleteCUDAStreamGuardNullHandle) { + AOTITorchError error = aoti_torch_delete_cuda_stream_guard(nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, GetCurrentCUDAStream) { + void* ret_stream = nullptr; + AOTITorchError error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(ret_stream, nullptr); +} + +TEST_F(AOTITorchCUDAGuardTest, GetCurrentCUDAStreamNullReturnPointer) { + AOTITorchError error = aoti_torch_get_current_cuda_stream(0, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +TEST_F(AOTITorchCUDAGuardTest, StreamGuardWithSameDevice) { + ASSERT_EQ(cudaSetDevice(0), cudaSuccess); + + cudaStream_t stream1, stream2; + ASSERT_EQ(cudaStreamCreate(&stream1), cudaSuccess); + ASSERT_EQ(cudaStreamCreate(&stream2), cudaSuccess); + + CUDAStreamGuardHandle guard1 = nullptr; + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(stream1, 0, &guard1); + EXPECT_EQ(error, Error::Ok); + + void* ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), stream1); + + CUDAStreamGuardHandle guard2 = nullptr; + error = aoti_torch_create_cuda_stream_guard(stream2, 0, &guard2); + EXPECT_EQ(error, Error::Ok); + + ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), stream2); + + error = aoti_torch_delete_cuda_stream_guard(guard2); + EXPECT_EQ(error, Error::Ok); + + ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), stream1); + + error = aoti_torch_delete_cuda_stream_guard(guard1); + EXPECT_EQ(error, Error::Ok); + + ASSERT_EQ(cudaStreamDestroy(stream1), cudaSuccess); + ASSERT_EQ(cudaStreamDestroy(stream2), cudaSuccess); +} + +TEST_F(AOTITorchCUDAGuardTest, GetCurrentStreamAfterSetStream) { + cudaStream_t new_stream; + ASSERT_EQ(cudaStreamCreate(&new_stream), cudaSuccess); + + CUDAStreamGuardHandle guard = nullptr; + AOTITorchError error = + aoti_torch_create_cuda_stream_guard(new_stream, 0, &guard); + EXPECT_EQ(error, Error::Ok); + + void* ret_stream = nullptr; + error = aoti_torch_get_current_cuda_stream(0, &ret_stream); + EXPECT_EQ(error, Error::Ok); + EXPECT_EQ(static_cast(ret_stream), new_stream); + + error = aoti_torch_delete_cuda_stream_guard(guard); + EXPECT_EQ(error, Error::Ok); + + ASSERT_EQ(cudaStreamDestroy(new_stream), cudaSuccess); +} diff --git a/runtime/platform/log.h b/runtime/platform/log.h index 72ea8528442..7293fa2428d 100644 --- a/runtime/platform/log.h +++ b/runtime/platform/log.h @@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel; ##__VA_ARGS__); \ } \ } while (0) + +/** + * Check a condition and log an error message if the condition is false. + * + * @param[in] _condition The condition to check. + * @param[in] _format Log message format string. + */ +#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \ + do { \ + if (!(_condition)) { \ + ET_LOG(Error, _format, ##__VA_ARGS__); \ + } \ + } while (0) + #else // ET_LOG_ENABLED /** @@ -191,4 +205,12 @@ using ::executorch::runtime::LogLevel; */ #define ET_LOG(_level, _format, ...) ((void)0) +/** + * Check a condition and log an error message if the condition is false. + * + * @param[in] _condition The condition to check. + * @param[in] _format Log message format string. + */ +#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) ((void)0) + #endif // ET_LOG_ENABLED