Skip to content

Commit 056ccb9

Browse files
committed
introduce shim layers for cudaguard and cudastreamguard
Pull Request resolved: #14902 ### Summary This diff introduces shim layers for CudaGuard and CudaStreamGuard in the Executorch project, which will be further used by cuda-AOTI models for stream/cuda control. The changes include: * Adding a new source file `runtime/shims/cuda_guard.cpp` and header file `runtime/shims/cuda_guard.h` to the `CMakeLists.txt` and `TARGETS` files. * Creating a new test target `aoti_torch_cuda_guard` in the `targets.bzl` file. * Defining the `cuda_guard.h` header file with the necessary includes, namespace definitions, and function declarations. These changes aim to provide a shim layer for CudaGuard, which is responsible for handling CUDA-related functionality in the Executorch runtime. The shim layer will allow for better modularity and maintainability of the codebase. ghstack-source-id: 314984327 @exported-using-ghexport Differential Revision: [D84126634](https://our.internmc.facebook.com/intern/diff/D84126634/)
1 parent f32e9fc commit 056ccb9

File tree

10 files changed

+464
-10
lines changed

10 files changed

+464
-10
lines changed

backends/aoti/CMakeLists.txt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
4141
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
4242

4343
# Link against PyTorch libraries and standard libraries
44-
target_link_libraries(
45-
aoti_common
46-
PUBLIC extension_tensor ${CMAKE_DL_LIBS}
47-
# Link PyTorch libraries for AOTI functions
48-
${TORCH_LIBRARIES}
49-
)
44+
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
5045
executorch_target_link_options_shared_lib(aoti_common)
5146

5247
install(

backends/aoti/aoti_model_container.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
7777
void* so_handle;
7878
std::string so_path;
7979
AOTInductorModelContainerHandle container_handle;
80+
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
81+
// dependency
8082
};
8183

8284
} // namespace aoti

backends/cuda/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
3535
find_package_torch()
3636

3737
# CUDA-specific AOTI functionality
38-
set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
39-
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
38+
set(_aoti_cuda_sources
39+
runtime/cuda_backend.cpp runtime/shims/memory.cpp
40+
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
41+
runtime/shims/cuda_guard.cpp
4042
)
4143
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
4244
target_include_directories(

backends/cuda/runtime/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ runtime.cxx_library(
66
name = "runtime_shims",
77
srcs = [
88
"guard.cpp",
9+
"shims/cuda_guard.cpp",
910
"shims/memory.cpp",
1011
"shims/tensor_attribute.cpp",
1112
],
1213
headers = [
1314
"guard.h",
15+
"shims/cuda_guard.h",
1416
"shims/memory.h",
1517
"shims/tensor_attribute.h",
1618
"utils.h",

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <cuda_runtime.h>
910
#include <dlfcn.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
@@ -16,14 +17,14 @@
1617

1718
#include <filesystem>
1819
#include <fstream>
19-
#include <iostream>
2020
#include <string>
2121
#include <vector>
2222

2323
// Include our shim layer headers
2424
#include <executorch/backends/aoti/aoti_model_container.h>
2525
#include <executorch/backends/aoti/common_shims.h>
2626
#include <executorch/backends/cuda/runtime/shims/memory.h>
27+
#include <executorch/backends/cuda/runtime/utils.h>
2728

2829
namespace executorch {
2930
namespace backends {
@@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
182183
handle->so_handle = so_handle;
183184
handle->so_path = so_path.string();
184185
handle->container_handle = container_handle;
186+
187+
// Create a CUDA stream for asynchronous execution
188+
cudaStream_t cuda_stream;
189+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
190+
handle->cuda_stream = static_cast<void*>(cuda_stream);
191+
185192
return (DelegateHandle*)handle; // Return the handle post-processing
186193
}
187194

@@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
288295
n_inputs,
289296
gpu_outputs.data(), // Use GPU output tensors
290297
n_outputs,
291-
nullptr, // Pass the actual CUDA stream!
298+
handle->cuda_stream, // Pass the actual CUDA stream
292299
nullptr); // proxy_executor_handle can remain nullptr
293300

294301
if (error != Error::Ok) {
@@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
334341
}
335342
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
336343

344+
// Destroy the CUDA stream if it exists
345+
if (handle->cuda_stream != nullptr) {
346+
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
347+
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
348+
ET_CHECK_OR_LOG_ERROR(
349+
stream_err == cudaSuccess,
350+
"Failed to destroy CUDA stream: %s",
351+
cudaGetErrorString(stream_err));
352+
handle->cuda_stream = nullptr;
353+
}
354+
337355
// Delete the container BEFORE closing the shared library
338356
if (handle->container_handle != nullptr) {
339357
AOTIRuntimeError delete_result =
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>
10+
11+
namespace executorch {
12+
namespace backends {
13+
namespace cuda {
14+
15+
extern "C" {
16+
17+
AOTITorchError aoti_torch_create_cuda_guard(
18+
int32_t device_index,
19+
CUDAGuardHandle* ret_guard) {
20+
ET_CHECK_OR_RETURN_ERROR(
21+
ret_guard != nullptr,
22+
InvalidArgument,
23+
"aoti_torch_create_cuda_guard failed: ret_guard is null");
24+
25+
auto result = CUDAGuard::create(device_index);
26+
if (!result.ok()) {
27+
return result.error();
28+
}
29+
*ret_guard = new CUDAGuard(std::move(result.get()));
30+
return Error::Ok;
31+
}
32+
33+
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
34+
ET_CHECK_OR_RETURN_ERROR(
35+
guard != nullptr,
36+
InvalidArgument,
37+
"aoti_torch_delete_cuda_guard failed: guard is null");
38+
39+
delete guard;
40+
return Error::Ok;
41+
}
42+
43+
AOTITorchError aoti_torch_cuda_guard_set_index(
44+
CUDAGuardHandle guard,
45+
int32_t device_index) {
46+
ET_CHECK_OR_RETURN_ERROR(
47+
guard != nullptr,
48+
InvalidArgument,
49+
"aoti_torch_cuda_guard_set_index failed: guard is null");
50+
51+
ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index));
52+
return Error::Ok;
53+
}
54+
55+
AOTITorchError aoti_torch_create_cuda_stream_guard(
56+
void* stream,
57+
int32_t device_index,
58+
CUDAStreamGuardHandle* ret_guard) {
59+
ET_CHECK_OR_RETURN_ERROR(
60+
ret_guard != nullptr,
61+
InvalidArgument,
62+
"aoti_torch_create_cuda_stream_guard failed: ret_guard is null");
63+
64+
ET_CHECK_OR_RETURN_ERROR(
65+
stream != nullptr,
66+
InvalidArgument,
67+
"aoti_torch_create_cuda_stream_guard failed: stream is null");
68+
69+
auto result =
70+
CUDAStreamGuard::create(static_cast<cudaStream_t>(stream), device_index);
71+
if (!result.ok()) {
72+
return result.error();
73+
}
74+
*ret_guard = new CUDAStreamGuard(std::move(result.get()));
75+
return Error::Ok;
76+
}
77+
78+
AOTITorchError aoti_torch_delete_cuda_stream_guard(
79+
CUDAStreamGuardHandle guard) {
80+
ET_CHECK_OR_RETURN_ERROR(
81+
guard != nullptr,
82+
InvalidArgument,
83+
"aoti_torch_delete_cuda_stream_guard failed: guard is null");
84+
85+
delete guard;
86+
return Error::Ok;
87+
}
88+
89+
AOTITorchError aoti_torch_get_current_cuda_stream(
90+
int32_t device_index,
91+
void** ret_stream) {
92+
ET_CHECK_OR_RETURN_ERROR(
93+
ret_stream != nullptr,
94+
InvalidArgument,
95+
"aoti_torch_get_current_cuda_stream failed: ret_stream is null");
96+
97+
auto result = getCurrentCUDAStream(device_index);
98+
if (!result.ok()) {
99+
return result.error();
100+
}
101+
*ret_stream = static_cast<void*>(result.get());
102+
return Error::Ok;
103+
}
104+
105+
} // extern "C"
106+
107+
} // namespace cuda
108+
} // namespace backends
109+
} // namespace executorch
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cuda_runtime.h>
12+
#include <executorch/backends/aoti/common_shims.h>
13+
#include <executorch/backends/cuda/runtime/guard.h>
14+
#include <cstdint>
15+
16+
namespace executorch {
17+
namespace backends {
18+
namespace cuda {
19+
20+
using executorch::backends::aoti::AOTITorchError;
21+
22+
extern "C" {
23+
24+
// Handle types for CUDA guards
25+
using CUDAGuardHandle = CUDAGuard*;
26+
using CUDAStreamGuardHandle = CUDAStreamGuard*;
27+
28+
/**
29+
* Creates a CUDA device guard that sets the current device and restores it
30+
* upon destruction.
31+
*
32+
* @param device_index The device index to set as current
33+
* @param ret_guard Output parameter for the created guard handle (must not be
34+
* null)
35+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
36+
* failure)
37+
*/
38+
AOTITorchError aoti_torch_create_cuda_guard(
39+
int32_t device_index,
40+
CUDAGuardHandle* ret_guard);
41+
42+
/**
43+
* Deletes a CUDA device guard and frees its associated resources.
44+
*
45+
* @param guard Handle to the guard to be deleted
46+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
47+
* failure)
48+
*/
49+
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);
50+
51+
/**
52+
* Sets the CUDA device to a new index for an existing guard.
53+
*
54+
* @param guard Handle to the guard
55+
* @param device_index The device index to set as current
56+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
57+
* failure)
58+
*/
59+
AOTITorchError aoti_torch_cuda_guard_set_index(
60+
CUDAGuardHandle guard,
61+
int32_t device_index);
62+
63+
/**
64+
* Creates a CUDA stream guard that sets the current device and stream,
65+
* restoring both upon destruction.
66+
*
67+
* @param stream The CUDA stream to set as current
68+
* @param device_index The device index for the stream
69+
* @param ret_guard Output parameter for the created guard handle (must not be
70+
* null)
71+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
72+
* failure)
73+
*/
74+
AOTITorchError aoti_torch_create_cuda_stream_guard(
75+
void* stream,
76+
int32_t device_index,
77+
CUDAStreamGuardHandle* ret_guard);
78+
79+
/**
80+
* Deletes a CUDA stream guard and frees its associated resources.
81+
*
82+
* @param guard Handle to the stream guard to be deleted
83+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
84+
* failure)
85+
*/
86+
AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);
87+
88+
/**
89+
* Gets the current CUDA stream for a specified device.
90+
*
91+
* @param device_index The device index (-1 to use current device)
92+
* @param ret_stream Output parameter for the current stream (must not be null)
93+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
94+
* failure)
95+
*/
96+
AOTITorchError aoti_torch_get_current_cuda_stream(
97+
int32_t device_index,
98+
void** ret_stream);
99+
100+
} // extern "C"
101+
102+
} // namespace cuda
103+
} // namespace backends
104+
} // namespace executorch

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ def define_common_targets():
3232
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
3333
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
3434
cuda_shim_cpp_unittest("aoti_torch_copy_")
35+
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")

0 commit comments

Comments
 (0)