Skip to content

Commit 334b232

Browse files
chhwang and Copilot authored
Fix GpuStreamPool to be aware of the device ID of streams (#590)
Co-authored-by: Copilot <[email protected]>
1 parent c580e4c commit 334b232

File tree

2 files changed

+25 -9 lines changed

2 files changed

+25 -9 lines changed

include/mscclpp/gpu_utils.hpp

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#define MSCCLPP_GPU_UTILS_HPP_
66

77
#include <memory>
8+
#include <unordered_map>
89

910
#include "env.hpp"
1011
#include "errors.hpp"
@@ -47,7 +48,7 @@ struct AvoidCudaGraphCaptureGuard {
4748
struct CudaStreamWithFlags {
4849
/// Constructor without flags. This will not create any stream. set() can be called later to create a stream with
4950
/// specified flags.
50-
CudaStreamWithFlags() : stream_(nullptr) {}
51+
CudaStreamWithFlags();
5152

5253
/// Constructor with flags. This will create a stream with the specified flags on the current device.
5354
/// @param flags The flags to create the stream with.
@@ -56,8 +57,8 @@ struct CudaStreamWithFlags {
5657
/// Destructor. This will destroy the stream if it was created.
5758
~CudaStreamWithFlags();
5859

59-
/// Set the stream with the specified flags. If the stream was already created, it will raise an error with
60-
/// ErrorCode::InvalidUsage.
60+
/// Set the stream with the specified flags. The current device at the time of the construction will be used. If the
61+
/// stream was already created, it will raise an error with ErrorCode::InvalidUsage.
6162
/// @param flags The flags to create the stream with.
6263
/// @throws Error if the stream was already created.
6364
void set(unsigned int flags);
@@ -68,7 +69,10 @@ struct CudaStreamWithFlags {
6869

6970
operator cudaStream_t() const { return stream_; }
7071

72+
int deviceId() const { return deviceId_; }
73+
7174
cudaStream_t stream_;
75+
int deviceId_;
7276
};
7377

7478
class GpuStreamPool;
@@ -89,6 +93,8 @@ class GpuStream {
8993

9094
operator cudaStream_t() const { return stream_->stream_; }
9195

96+
int deviceId() const { return stream_->deviceId(); }
97+
9298
private:
9399
friend class GpuStreamPool;
94100

@@ -111,7 +117,7 @@ class GpuStreamPool {
111117

112118
protected:
113119
friend class GpuStream;
114-
std::vector<std::shared_ptr<CudaStreamWithFlags>> streams_;
120+
std::unordered_map<int, std::vector<std::shared_ptr<CudaStreamWithFlags>>> streams_;
115121
};
116122

117123
/// Get the singleton instance of GpuStreamPool.

src/gpu_utils.cc

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,10 @@ AvoidCudaGraphCaptureGuard::AvoidCudaGraphCaptureGuard() : mode_(cudaStreamCaptu
1313

1414
AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard() { (void)cudaThreadExchangeStreamCaptureMode(&mode_); }
1515

16+
CudaStreamWithFlags::CudaStreamWithFlags() : stream_(nullptr) { MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_)); }
17+
1618
CudaStreamWithFlags::CudaStreamWithFlags(unsigned int flags) {
19+
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_));
1720
MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&stream_, flags));
1821
}
1922

@@ -23,22 +26,29 @@ CudaStreamWithFlags::~CudaStreamWithFlags() {
2326

2427
void CudaStreamWithFlags::set(unsigned int flags) {
2528
if (!empty()) throw Error("CudaStreamWithFlags already set", ErrorCode::InvalidUsage);
29+
int originalDeviceId;
30+
MSCCLPP_CUDATHROW(cudaGetDevice(&originalDeviceId)); // Save the current device
31+
MSCCLPP_CUDATHROW(cudaSetDevice(deviceId_));
2632
MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&stream_, flags));
33+
MSCCLPP_CUDATHROW(cudaSetDevice(originalDeviceId)); // Restore the original device
2734
}
2835

2936
bool CudaStreamWithFlags::empty() const { return stream_ == nullptr; }
3037

3138
GpuStream::GpuStream(std::shared_ptr<GpuStreamPool> pool, std::shared_ptr<CudaStreamWithFlags> stream)
3239
: pool_(pool), stream_(stream) {}
3340

34-
GpuStream::~GpuStream() { pool_->streams_.push_back(stream_); }
41+
GpuStream::~GpuStream() { pool_->streams_[deviceId()].push_back(stream_); }
3542

3643
GpuStreamPool::GpuStreamPool() {}
3744

3845
GpuStream GpuStreamPool::getStream() {
39-
if (!streams_.empty()) {
40-
auto stream = streams_.back();
41-
streams_.pop_back();
46+
int deviceId;
47+
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
48+
auto& streamVec = streams_[deviceId];
49+
if (!streamVec.empty()) {
50+
auto stream = streamVec.back();
51+
streamVec.pop_back();
4252
return GpuStream(gpuStreamPool(), stream);
4353
}
4454
return GpuStream(gpuStreamPool(), std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking));
@@ -47,7 +57,7 @@ GpuStream GpuStreamPool::getStream() {
4757
void GpuStreamPool::clear() { streams_.clear(); }
4858

4959
// A global pool instance
50-
std::shared_ptr<GpuStreamPool> gGpuStreamPool_;
60+
static std::shared_ptr<GpuStreamPool> gGpuStreamPool_;
5161

5262
std::shared_ptr<GpuStreamPool> gpuStreamPool() {
5363
if (!gGpuStreamPool_) {

0 commit comments

Comments
 (0)