Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
void* so_handle;
std::string so_path;
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
};

} // namespace aoti
Expand Down
22 changes: 20 additions & 2 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <dlfcn.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
Expand All @@ -16,14 +17,14 @@

#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Include our shim layer headers
#include <executorch/backends/aoti/aoti_model_container.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;

// Create a CUDA stream for asynchronous execution
cudaStream_t cuda_stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
handle->cuda_stream = static_cast<void*>(cuda_stream);

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand Down Expand Up @@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
n_inputs,
gpu_outputs.data(), // Use GPU output tensors
n_outputs,
nullptr, // Pass the actual CUDA stream!
handle->cuda_stream, // Pass the actual CUDA stream
nullptr); // proxy_executor_handle can remain nullptr

if (error != Error::Ok) {
Expand Down Expand Up @@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
}
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Destroy the CUDA stream if it exists
if (handle->cuda_stream != nullptr) {
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
ET_CHECK_OR_LOG(
stream_err == cudaSuccess,
"Failed to destroy CUDA stream: %s",
cudaGetErrorString(stream_err));
handle->cuda_stream = nullptr;
}

// Delete the container BEFORE closing the shared library
if (handle->container_handle != nullptr) {
AOTIRuntimeError delete_result =
Expand Down
22 changes: 22 additions & 0 deletions runtime/platform/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel;
##__VA_ARGS__); \
} \
} while (0)

/**
* Check a condition and log an error message if the condition is false.
*
* @param[in] _condition The condition to check.
* @param[in] _format Log message format string.
*/
#define ET_CHECK_OR_LOG(_condition, _format, ...) \
do { \
if (!(_condition)) { \
ET_LOG(Error, _format, ##__VA_ARGS__); \
} \
} while (0)

Comment on lines 184 to 197
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make it explicit that we are logging at Error level. Maybe ET_CHECK_OR_LOG_ERROR? The other thing you can do is let this macro take a logging level.

#else // ET_LOG_ENABLED

/**
Expand All @@ -191,4 +205,12 @@ using ::executorch::runtime::LogLevel;
*/
#define ET_LOG(_level, _format, ...) ((void)0)

/**
* Check a condition and log an error message if the condition is false.
*
* @param[in] _condition The condition to check.
* @param[in] _format Log message format string.
*/
#define ET_CHECK_OR_LOG(_condition, _format, ...) ((void)0)

#endif // ET_LOG_ENABLED
Loading