@@ -13,7 +13,10 @@ AvoidCudaGraphCaptureGuard::AvoidCudaGraphCaptureGuard() : mode_(cudaStreamCaptu
13
13
14
14
// Hands `mode_` back to cudaThreadExchangeStreamCaptureMode when the guard goes
// out of scope. The returned status is discarded on purpose, so this destructor
// itself never reports a CUDA failure.
// NOTE(review): presumably this undoes the exchange made in the constructor,
// whose body is not visible here -- confirm.
AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard() {
  static_cast<void>(cudaThreadExchangeStreamCaptureMode(&mode_));
}
15
15
16
+ CudaStreamWithFlags::CudaStreamWithFlags () : stream_(nullptr ) { MSCCLPP_CUDATHROW (cudaGetDevice (&deviceId_)); }
17
+
16
18
// Builds a wrapper that owns a freshly created stream.
//
// `flags` is forwarded unchanged to cudaStreamCreateWithFlags. Any CUDA failure
// is reported through MSCCLPP_CUDATHROW.
CudaStreamWithFlags::CudaStreamWithFlags(unsigned int flags) {
  // Record the device that is current at construction time.
  MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_));

  // Create the stream itself.
  MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&stream_, flags));
}
19
22
@@ -23,22 +26,29 @@ CudaStreamWithFlags::~CudaStreamWithFlags() {
23
26
24
27
void CudaStreamWithFlags::set (unsigned int flags) {
25
28
if (!empty ()) throw Error (" CudaStreamWithFlags already set" , ErrorCode::InvalidUsage);
29
+ int originalDeviceId;
30
+ MSCCLPP_CUDATHROW (cudaGetDevice (&originalDeviceId)); // Save the current device
31
+ MSCCLPP_CUDATHROW (cudaSetDevice (deviceId_));
26
32
MSCCLPP_CUDATHROW (cudaStreamCreateWithFlags (&stream_, flags));
33
+ MSCCLPP_CUDATHROW (cudaSetDevice (originalDeviceId)); // Restore the original device
27
34
}
28
35
29
36
// Reports whether this wrapper currently holds no stream, i.e. `stream_` is null.
bool CudaStreamWithFlags::empty() const {
  const bool hasNoStream = (nullptr == stream_);
  return hasNoStream;
}
30
37
31
38
GpuStream::GpuStream (std::shared_ptr<GpuStreamPool> pool, std::shared_ptr<CudaStreamWithFlags> stream)
32
39
: pool_(pool), stream_(stream) {}
33
40
34
- GpuStream::~GpuStream () { pool_->streams_ .push_back (stream_); }
41
+ GpuStream::~GpuStream () { pool_->streams_ [ deviceId ()] .push_back (stream_); }
35
42
36
43
// A new pool needs no setup of its own; members are default-initialized.
GpuStreamPool::GpuStreamPool() = default;
37
44
38
45
GpuStream GpuStreamPool::getStream () {
39
- if (!streams_.empty ()) {
40
- auto stream = streams_.back ();
41
- streams_.pop_back ();
46
+ int deviceId;
47
+ MSCCLPP_CUDATHROW (cudaGetDevice (&deviceId));
48
+ auto & streamVec = streams_[deviceId];
49
+ if (!streamVec.empty ()) {
50
+ auto stream = streamVec.back ();
51
+ streamVec.pop_back ();
42
52
return GpuStream (gpuStreamPool (), stream);
43
53
}
44
54
return GpuStream (gpuStreamPool (), std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking));
@@ -47,7 +57,7 @@ GpuStream GpuStreamPool::getStream() {
47
57
// Empties the pool by clearing the `streams_` container, dropping every stream
// reference the pool currently holds.
void GpuStreamPool::clear() {
  streams_.clear();
}
48
58
49
59
// A global pool instance
50
- std::shared_ptr<GpuStreamPool> gGpuStreamPool_ ;
60
+ static std::shared_ptr<GpuStreamPool> gGpuStreamPool_ ;
51
61
52
62
std::shared_ptr<GpuStreamPool> gpuStreamPool () {
53
63
if (!gGpuStreamPool_ ) {
0 commit comments