Skip to content

Commit 51094cc

Browse files
committed
Update on "introduce shim layers for cudaguard and cudastreamguard"
### Summary

This diff introduces shim layers for CudaGuard and CudaStreamGuard in the ExecuTorch project; CUDA AOTI models will use them for stream/CUDA control. The changes include:

* Adding a new source file `runtime/shims/cuda_guard.cpp` and header file `runtime/shims/cuda_guard.h` to the `CMakeLists.txt` and `TARGETS` files.
* Creating a new test target `aoti_torch_cuda_guard` in the `targets.bzl` file.
* Defining the `cuda_guard.h` header file with the necessary includes, namespace definitions, and function declarations.

These changes provide a shim layer for CudaGuard, which is responsible for handling CUDA-related functionality in the ExecuTorch runtime. The shim layer is intended to improve the modularity and maintainability of the codebase.

Differential Revision: [D84126634](https://our.internmc.facebook.com/intern/diff/D84126634/)

[ghstack-poisoned]
2 parents b23d9ef + c46daf6 commit 51094cc

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

backends/aoti/aoti_model_container.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
7777
void* so_handle;
7878
std::string so_path;
7979
AOTInductorModelContainerHandle container_handle;
80+
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
81+
// dependency
8082
};
8183

8284
} // namespace aoti

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <cuda_runtime.h>
910
#include <dlfcn.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
@@ -16,14 +17,14 @@
1617

1718
#include <filesystem>
1819
#include <fstream>
19-
#include <iostream>
2020
#include <string>
2121
#include <vector>
2222

2323
// Include our shim layer headers
2424
#include <executorch/backends/aoti/aoti_model_container.h>
2525
#include <executorch/backends/aoti/common_shims.h>
2626
#include <executorch/backends/cuda/runtime/shims/memory.h>
27+
#include <executorch/backends/cuda/runtime/utils.h>
2728

2829
namespace executorch {
2930
namespace backends {
@@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
182183
handle->so_handle = so_handle;
183184
handle->so_path = so_path.string();
184185
handle->container_handle = container_handle;
186+
187+
// Create a CUDA stream for asynchronous execution
188+
cudaStream_t cuda_stream;
189+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
190+
handle->cuda_stream = static_cast<void*>(cuda_stream);
191+
185192
return (DelegateHandle*)handle; // Return the handle post-processing
186193
}
187194

@@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
288295
n_inputs,
289296
gpu_outputs.data(), // Use GPU output tensors
290297
n_outputs,
291-
nullptr, // Pass the actual CUDA stream!
298+
handle->cuda_stream, // Pass the actual CUDA stream
292299
nullptr); // proxy_executor_handle can remain nullptr
293300

294301
if (error != Error::Ok) {
@@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
334341
}
335342
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
336343

344+
// Destroy the CUDA stream if it exists
345+
if (handle->cuda_stream != nullptr) {
346+
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
347+
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
348+
ET_CHECK_OR_LOG_ERROR(
349+
stream_err == cudaSuccess,
350+
"Failed to destroy CUDA stream: %s",
351+
cudaGetErrorString(stream_err));
352+
handle->cuda_stream = nullptr;
353+
}
354+
337355
// Delete the container BEFORE closing the shared library
338356
if (handle->container_handle != nullptr) {
339357
AOTIRuntimeError delete_result =

runtime/platform/log.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel;
181181
##__VA_ARGS__); \
182182
} \
183183
} while (0)
184+
185+
/**
186+
* Check a condition and, if it is false, log an Error-level message through ET_LOG. This macro only logs: it does not abort or return.
187+
*
188+
* @param[in] _condition The condition to check; it is evaluated exactly once.
189+
* @param[in] _format Log message format string; any further arguments are forwarded to ET_LOG together with it.
190+
*/
191+
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
192+
do { \
193+
if (!(_condition)) { \
194+
ET_LOG(Error, _format, ##__VA_ARGS__); \
195+
} \
196+
} while (0)
197+
184198
#else // ET_LOG_ENABLED
185199

186200
/**
@@ -191,4 +205,12 @@ using ::executorch::runtime::LogLevel;
191205
*/
192206
#define ET_LOG(_level, _format, ...) ((void)0)
193207

208+
/**
209+
* Check a condition and log an error message if the condition is false.
210+
*
211+
* @param[in] _condition The condition to check.
212+
* @param[in] _format Log message format string.
213+
*/
214+
// Logging-disabled variant: nothing is logged and nothing is evaluated at run
// time. The condition is still named inside an unevaluated sizeof so that
// variables used only in the check (e.g. a saved error code) do not trigger
// unused-variable warnings when logging is compiled out. The format string and
// its arguments are discarded entirely.
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
  ((void)sizeof(!(_condition)))
215+
194216
#endif // ET_LOG_ENABLED

0 commit comments

Comments
 (0)