Skip to content

Commit baf697f

Browse files
committed
Update on "introduce cuda stream into runtime backend"
This diff introduces CUDA streams into the Executorch runtime backend. The changes include: * Adding CUDA stream support to the `cuda_backend.cpp` file * Including the `cuda_runtime.h` header file in `cuda_backend.cpp` * Adding a `void* cuda_stream` field to the `AOTInductorModelContainer` struct in `aoti_model_container.h` to store the CUDA stream * Defining a new macro `ET_CHECK_OR_LOG` in `log.h` to check a condition and log an error message if the condition is false. Differential Revision: [D84128173](https://our.internmc.facebook.com/intern/diff/D84128173/) [ghstack-poisoned]
2 parents e69db18 + 31ea976 commit baf697f

File tree

3 files changed

+8
-13
lines changed

3 files changed

+8
-13
lines changed

backends/aoti/CMakeLists.txt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
4141
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
4242

4343
# Link against PyTorch libraries and standard libraries
44-
target_link_libraries(
45-
aoti_common
46-
PUBLIC extension_tensor ${CMAKE_DL_LIBS}
47-
# Link PyTorch libraries for AOTI functions
48-
${TORCH_LIBRARIES}
49-
)
44+
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
5045
executorch_target_link_options_shared_lib(aoti_common)
5146

5247
install(

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ class ET_EXPERIMENTAL CudaBackend final
345345
if (handle->cuda_stream != nullptr) {
346346
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
347347
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
348-
ET_CHECK_OR_LOG(
348+
ET_CHECK_OR_LOG_ERROR(
349349
stream_err == cudaSuccess,
350350
"Failed to destroy CUDA stream: %s",
351351
cudaGetErrorString(stream_err));

runtime/platform/log.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,11 @@ using ::executorch::runtime::LogLevel;
188188
* @param[in] _condition The condition to check.
189189
* @param[in] _format Log message format string.
190190
*/
191-
#define ET_CHECK_OR_LOG(_condition, _format, ...) \
192-
do { \
193-
if (!(_condition)) { \
194-
ET_LOG(Error, _format, ##__VA_ARGS__); \
195-
} \
191+
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
192+
do { \
193+
if (!(_condition)) { \
194+
ET_LOG(Error, _format, ##__VA_ARGS__); \
195+
} \
196196
} while (0)
197197

198198
#else // ET_LOG_ENABLED
@@ -211,6 +211,6 @@ using ::executorch::runtime::LogLevel;
211211
* @param[in] _condition The condition to check.
212212
* @param[in] _format Log message format string.
213213
*/
214-
#define ET_CHECK_OR_LOG(_condition, _format, ...) ((void)0)
214+
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) ((void)0)
215215

216216
#endif // ET_LOG_ENABLED

0 commit comments

Comments
 (0)