Skip to content

Commit 0a7ac54

Browse files
committed
Introduce CUDAGuard and CUDAStreamGuard
### Introduce CUDAGuard and CUDAStreamGuard This diff introduces `CUDAGuard` and `CUDAStreamGuard` in the ExecuTorch CUDA runtime. These classes provide a convenient way to manage CUDA device and stream selection. #### Changes * Added `CUDAGuard` and `CUDAStreamGuard` classes in `fbcode/executorch/backends/cuda/runtime/guard.h`. * Implemented `CUDAGuard` and `CUDAStreamGuard` in `fbcode/executorch/backends/cuda/runtime/guard.cpp`. * Added unit tests for `CUDAStreamGuard` in `fbcode/executorch/backends/cuda/runtime/tests/test_cuda_stream_guard.cpp`. * Updated `TARGETS` file to include the new files. #### Purpose The `CUDAGuard` class provides a way to select a CUDA device and ensure that the previously active device is restored when the guard goes out of scope. The `CUDAStreamGuard` class provides a way to select a CUDA stream and ensure that the previously current stream is restored when the guard goes out of scope. #### Usage They will be further used and controlled by their shim layer functions. Differential Revision: [D84126481](https://our.internmc.facebook.com/intern/diff/D84126481/) ghstack-source-id: 314867936 Pull Request resolved: #14901
1 parent 7e2f75e commit 0a7ac54

File tree

8 files changed

+759
-1
lines changed

8 files changed

+759
-1
lines changed

backends/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ find_package_torch()
3636

3737
# CUDA-specific AOTI functionality
3838
set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
39-
runtime/shims/tensor_attribute.cpp
39+
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
4040
)
4141
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
4242
target_include_directories(

backends/cuda/runtime/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ oncall("executorch")
55
runtime.cxx_library(
66
name = "runtime_shims",
77
srcs = [
8+
"guard.cpp",
89
"shims/memory.cpp",
910
"shims/tensor_attribute.cpp",
1011
],
1112
headers = [
13+
"guard.h",
1214
"shims/memory.h",
1315
"shims/tensor_attribute.h",
1416
"utils.h",

backends/cuda/runtime/guard.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cuda/runtime/guard.h>
10+
#include <executorch/runtime/platform/log.h>
11+
12+
namespace executorch {
13+
namespace backends {
14+
namespace cuda {
15+
16+
namespace {
17+
// Thread-local stream storage (private to this file)
18+
thread_local std::unordered_map<DeviceIndex, cudaStream_t> current_streams_;
19+
} // namespace
20+
21+
/**
 * Registers `stream` as the calling thread's current stream for a device.
 *
 * @param stream The stream to record (stored as-is, including nullptr)
 * @param device_index Target device; -1 selects the device reported by
 *     cudaGetDevice()
 * @return Error::Ok on success, or the error produced when the active device
 *     cannot be queried
 */
Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) {
  DeviceIndex target_device = device_index;
  if (target_device == -1) {
    // No explicit device requested: resolve to the one that is active now.
    int active_device;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));
    target_device = active_device;
  }

  // Inserts a new entry or overwrites the existing one for this device.
  current_streams_[target_device] = stream;
  return Error::Ok;
}
32+
33+
/**
 * Returns the calling thread's current stream for a device, creating and
 * registering a new stream the first time the device is queried.
 *
 * A stream belongs to the device that is active when it is created, so when
 * `device_index` differs from the active device this function switches to
 * `device_index` for the duration of cudaStreamCreate() and then switches
 * back.
 *
 * NOTE(review): streams created here are only stored in the thread-local
 * table; nothing in this file destroys them.
 *
 * @param device_index Device to query; -1 selects the device reported by
 *     cudaGetDevice()
 * @return The registered (or newly created) stream, or an error if a CUDA
 *     call fails
 */
Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index) {
  int active_device = -1;
  bool active_device_known = false;

  if (device_index == -1) {
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));
    device_index = active_device;
    active_device_known = true;
  }

  // Fast path: a stream was already registered for this device on this thread.
  const auto it = current_streams_.find(device_index);
  if (it != current_streams_.end()) {
    return it->second;
  }

  if (!active_device_known) {
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));
  }

  // cudaStreamCreate() binds the stream to the active device, so make the
  // requested device active first when it is not already.
  const bool switch_device = (active_device != device_index);
  if (switch_device) {
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(device_index));
  }

  cudaStream_t stream = nullptr;
  const cudaError_t create_status = cudaStreamCreate(&stream);

  if (switch_device) {
    // Always try to put the caller's device back, whether or not the stream
    // was created.
    const cudaError_t restore_status = cudaSetDevice(active_device);
    if (restore_status != cudaSuccess) {
      if (create_status == cudaSuccess) {
        // Do not keep a stream that is not going to be registered.
        (void)cudaStreamDestroy(stream);
      }
      ET_CUDA_CHECK_OR_RETURN_ERROR(restore_status);
    }
  }
  ET_CUDA_CHECK_OR_RETURN_ERROR(create_status);

  // device_index is already resolved, so write the entry directly rather
  // than calling setCurrentCUDAStream() and discarding its Error.
  current_streams_[device_index] = stream;
  return stream;
}
50+
51+
CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept
52+
: original_device_index_(other.original_device_index_),
53+
current_device_index_(other.current_device_index_) {
54+
// Mark the moved-from object as "already restored" so its destructor doesn't
55+
// try to restore the device
56+
other.original_device_index_ = other.current_device_index_;
57+
}
58+
59+
/**
 * Switches back to the recorded original device when it differs from the
 * device this guard selected. A failure is logged; nothing is propagated.
 */
CUDAGuard::~CUDAGuard() {
  // Equal indices mean there is nothing to undo (no switch happened, or this
  // object was moved from).
  if (original_device_index_ == current_device_index_) {
    return;
  }

  const cudaError_t status = cudaSetDevice(original_device_index_);
  if (status == cudaSuccess) {
    return;
  }

  ET_LOG(
      Error,
      "~CUDAGuard: Failed to restore device to %d: %s",
      original_device_index_,
      cudaGetErrorString(status));
}
71+
72+
/**
 * Makes `device_index` the active CUDA device and records it as the device
 * selected by this guard.
 *
 * The device to restore on destruction is captured only while the guard holds
 * no pending switch (both recorded indices equal). Re-targeting a guard that
 * has already switched devices therefore keeps the device it first replaced,
 * instead of overwriting it with the guard's own previous selection.
 *
 * The recorded state is updated only after cudaSetDevice() succeeds, so a
 * failed call leaves the guard describing the device that is really active.
 *
 * @param device_index The device index to set as current
 * @return Error::Ok on success, or the error produced by a failing CUDA call
 */
Error CUDAGuard::set_index(DeviceIndex device_index) {
  int active_device = -1;
  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));

  // Nothing to undo yet (fresh, moved-from, or switched back): the device
  // that is active now is the one to come back to.
  if (original_device_index_ == current_device_index_) {
    original_device_index_ = active_device;
    current_device_index_ = active_device;
  }

  if (device_index != active_device) {
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(device_index));
  }

  current_device_index_ = device_index;
  return Error::Ok;
}
85+
86+
/**
 * Factory: builds a guard and switches it to `device_index`.
 *
 * @param device_index The device index to set as current
 * @return The guard on success, or the error reported by set_index()
 */
Result<CUDAGuard> CUDAGuard::create(DeviceIndex device_index) {
  CUDAGuard guard;
  const Error status = guard.set_index(device_index);
  ET_CHECK_OK_OR_RETURN_ERROR(status);
  return guard;
}
91+
92+
CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept
93+
: device_guard_(std::move(other.device_guard_)),
94+
original_stream_(other.original_stream_),
95+
current_stream_(other.current_stream_),
96+
device_index_(other.device_index_) {
97+
// Mark the moved-from object as "already restored" so its destructor doesn't
98+
// try to restore the stream
99+
other.original_stream_ = other.current_stream_;
100+
}
101+
102+
/**
 * Re-registers the recorded original stream for the guarded device when it
 * differs from the stream this guard installed. A failure is logged; nothing
 * is propagated.
 */
CUDAStreamGuard::~CUDAStreamGuard() {
  // Equal streams mean there is nothing to undo; this is also the state a
  // moved-from guard is left in. The test is deliberately not a null check:
  // nullptr denotes the default stream and is a legitimate value to restore.
  if (original_stream_ == current_stream_) {
    return;
  }

  const Error status = setCurrentCUDAStream(original_stream_, device_index_);
  if (status == Error::Ok) {
    return;
  }

  ET_LOG(
      Error,
      "~CUDAStreamGuard: Failed to restore stream for device %d",
      device_index_);
}
118+
119+
/**
 * Installs `stream` as the current stream for `device_index` and records the
 * stream it replaces so the destructor can put it back.
 *
 * If this guard already holds an override, that override is undone first.
 * The "original" stream captured below is then always the one that was
 * current before the guard intervened, not the guard's own previous stream.
 *
 * The guard's members are updated only after the new stream has been
 * installed, so a failed call leaves the guard with nothing extra to undo.
 *
 * NOTE(review): this does not switch the active device; callers are presumed
 * to pass the device already selected by the owned device guard - confirm.
 *
 * @param stream The CUDA stream to set as current
 * @param device_index The device index for the stream
 * @return Error::Ok on success, or the first error encountered
 */
Error CUDAStreamGuard::set_stream(
    cudaStream_t stream,
    DeviceIndex device_index) {
  // Same condition the destructor uses: differing streams mean an override
  // installed by this guard is still in place.
  if (original_stream_ != current_stream_) {
    ET_CHECK_OK_OR_RETURN_ERROR(
        setCurrentCUDAStream(original_stream_, device_index_));
    current_stream_ = original_stream_;
  }

  auto previous = getCurrentCUDAStream(device_index);
  if (!previous.ok()) {
    ET_LOG(Error, "Failed to get current stream for device %d", device_index);
    return previous.error();
  }

  ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index));

  original_stream_ = previous.get();
  current_stream_ = stream;
  device_index_ = device_index;

  return Error::Ok;
}
136+
137+
/**
 * Factory: selects `device_index` through a CUDAGuard, then installs `stream`
 * as that device's current stream.
 *
 * @param stream The CUDA stream to set as current
 * @param device_index The device index for the stream
 * @return The guard on success, or the first error encountered
 */
Result<CUDAStreamGuard> CUDAStreamGuard::create(
    cudaStream_t stream,
    DeviceIndex device_index) {
  // Device first: the stream guard takes ownership of the device guard.
  auto device_guard = CUDAGuard::create(device_index);
  ET_CHECK_OK_OR_RETURN_ERROR(device_guard.error());

  CUDAStreamGuard guard(std::move(device_guard.get()));
  ET_CHECK_OK_OR_RETURN_ERROR(guard.set_stream(stream, device_index));

  return guard;
}
148+
149+
} // namespace cuda
150+
} // namespace backends
151+
} // namespace executorch

backends/cuda/runtime/guard.h

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cuda_runtime.h>
12+
#include <executorch/backends/cuda/runtime/utils.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/result.h>
15+
#include <cstdint>
16+
17+
namespace executorch {
18+
namespace backends {
19+
namespace cuda {
20+
21+
using executorch::runtime::Error;
22+
using executorch::runtime::Result;
23+
24+
// Type alias for device index
25+
using DeviceIndex = int32_t;
26+
27+
/**
28+
* Set the current CUDA stream for the specified device.
29+
*
30+
* @param stream The CUDA stream to set as current
31+
* @param device_index The device index (-1 to use current device)
32+
* @return Error code indicating success or failure
33+
*/
34+
Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index = -1);
35+
36+
/**
37+
* Get the current CUDA stream for the specified device.
38+
* If no stream has been set, creates a new stream and sets it as current.
39+
*
40+
* @param device_index The device index (-1 to use current device)
41+
* @return Result containing the current stream on success, or an error code on
42+
* failure
43+
*/
44+
Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index = -1);
45+
46+
/**
47+
* RAII guard that sets the current CUDA device and restores it on destruction.
48+
* This ensures that the device is properly restored even if an exception
49+
* occurs.
50+
*
51+
*/
52+
class CUDAGuard {
53+
private:
54+
/**
55+
* Private constructor - use create() factory method instead.
56+
*/
57+
explicit CUDAGuard()
58+
: original_device_index_(-1), current_device_index_(-1) {}
59+
60+
public:
61+
/**
62+
* Factory method to create a CUDAGuard.
63+
*
64+
* @param device_index The device index to set as current
65+
* @return Result containing the guard on success, or an error code on failure
66+
*/
67+
static Result<CUDAGuard> create(DeviceIndex device_index);
68+
69+
// Copy is not allowed
70+
CUDAGuard(const CUDAGuard&) = delete;
71+
CUDAGuard& operator=(const CUDAGuard&) = delete;
72+
73+
// Move constructor and assignment
74+
CUDAGuard(CUDAGuard&& other) noexcept;
75+
CUDAGuard& operator=(CUDAGuard&& other) = delete;
76+
77+
/**
78+
* Destructor that restores the original device if necessary.
79+
*/
80+
~CUDAGuard();
81+
82+
/**
83+
* Sets the CUDA device to the given device index.
84+
*
85+
* @param device_index The device index to set as current
86+
* @return Error code indicating success or failure
87+
*/
88+
Error set_index(DeviceIndex device_index);
89+
90+
/**
91+
* Get the original device index before the guard was created.
92+
*
93+
* @return The original device index
94+
*/
95+
DeviceIndex original_device() const {
96+
return original_device_index_;
97+
}
98+
99+
/**
100+
* Get the current device index.
101+
*
102+
* @return The current device index
103+
*/
104+
DeviceIndex current_device() const {
105+
return current_device_index_;
106+
}
107+
108+
private:
109+
/// The original device before this guard was created
110+
DeviceIndex original_device_index_;
111+
/// The current device managed by this guard
112+
DeviceIndex current_device_index_;
113+
};
114+
115+
/**
116+
* RAII guard that sets the current CUDA device and stream, restoring both on
117+
* destruction. This is useful for temporarily switching to a different device
118+
* and stream.
119+
*
120+
*/
121+
class CUDAStreamGuard {
122+
private:
123+
// Private constructor that takes a CUDAGuard
124+
explicit CUDAStreamGuard(CUDAGuard&& guard)
125+
: device_guard_(std::move(guard)),
126+
original_stream_(nullptr),
127+
current_stream_(nullptr),
128+
device_index_(-1) {}
129+
130+
public:
131+
/**
132+
* Factory method to create a CUDAStreamGuard.
133+
*
134+
* @param stream The CUDA stream to set as current
135+
* @param device_index The device index for the stream
136+
* @return Result containing the guard on success, or an error code on failure
137+
*/
138+
static Result<CUDAStreamGuard> create(
139+
cudaStream_t stream,
140+
DeviceIndex device_index);
141+
142+
// Copy is not allowed
143+
CUDAStreamGuard(const CUDAStreamGuard&) = delete;
144+
CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;
145+
146+
// Move constructor and assignment
147+
CUDAStreamGuard(CUDAStreamGuard&& other) noexcept;
148+
CUDAStreamGuard& operator=(CUDAStreamGuard&& other) noexcept = delete;
149+
150+
/**
151+
* Destructor that restores the original stream and device.
152+
*/
153+
~CUDAStreamGuard();
154+
155+
/**
156+
* Sets the CUDA stream to the given stream on the specified device.
157+
*
158+
* @param stream The CUDA stream to set as current
159+
* @param device_index The device index for the stream
160+
* @return Error code indicating success or failure
161+
*/
162+
Error set_stream(cudaStream_t stream, DeviceIndex device_index);
163+
164+
/**
165+
* Get the current guarded stream.
166+
*
167+
* @return The current stream
168+
*/
169+
cudaStream_t stream() const {
170+
return current_stream_;
171+
}
172+
173+
/**
174+
* Get the device index being guarded.
175+
*
176+
* @return The device index
177+
*/
178+
DeviceIndex device_index() const {
179+
return device_index_;
180+
}
181+
182+
private:
183+
/// The device guard that handles device switching
184+
CUDAGuard device_guard_;
185+
/// The original stream that was current before this guard
186+
cudaStream_t original_stream_ = nullptr;
187+
/// The current stream being guarded
188+
cudaStream_t current_stream_ = nullptr;
189+
/// The device index for this stream guard
190+
DeviceIndex device_index_;
191+
};
192+
193+
} // namespace cuda
194+
} // namespace backends
195+
} // namespace executorch
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
2+
load(":targets.bzl", "define_common_targets")
3+
4+
oncall("executorch")
5+
6+
define_common_targets()

0 commit comments

Comments
 (0)