Skip to content

Commit c0b9df6

Browse files
authored
[STF] Use cudaStreamGetId instead of manually assigning unique IDs to streams (#7899)
* Stop manually assigning unique IDs to CUDA streams: cu(da)StreamGetId has been available since CUDA 12.0, so use it instead
* (un)register_stream is not needed at all
* clang-format
* Remove redundant ctors
* Add a paranoid assertion that cuStreamGetId never returns k_no_stream_id
1 parent 85befbd commit c0b9df6

File tree

6 files changed

+52
-133
lines changed

6 files changed

+52
-133
lines changed

cudax/include/cuda/experimental/__stf/internal/async_resources_handle.cuh

Lines changed: 44 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -37,33 +37,59 @@
3737
#include <mutex>
3838
#include <unordered_map>
3939

40+
#include <cuda.h>
41+
4042
namespace cuda::experimental::stf
4143
{
4244
class green_context_helper;
4345

4446
// Needed to set/get affinity
4547
class exec_place;
4648

49+
/** Sentinel for "no stream" / empty slot. Distinct from any value returned by cuStreamGetId. */
50+
inline constexpr unsigned long long k_no_stream_id = static_cast<unsigned long long>(-1);
51+
4752
/**
48-
* @brief A class to store a CUDA stream along with a few information to avoid CUDA queries
53+
* @brief Returns the unique stream ID from the CUDA driver (cuStreamGetId).
54+
* @param stream A valid CUDA stream, or nullptr.
55+
* @return The stream's unique ID, or k_no_stream_id if stream is nullptr.
56+
*/
57+
inline unsigned long long get_stream_id(cudaStream_t stream)
{
  // A null handle stands for "no stream" (empty pool slots pair nullptr with
  // k_no_stream_id), so report the sentinel instead of querying the driver.
  // This is the documented contract of this helper for a nullptr argument.
  if (stream == nullptr)
  {
    return k_no_stream_id;
  }

  unsigned long long id = 0;
  cuda_safe_call(cuStreamGetId(reinterpret_cast<CUstream>(stream), &id));

  // Paranoid check: a driver-provided ID must never collide with the sentinel,
  // otherwise a real stream would be treated as "no stream" by callers.
  _CCCL_ASSERT(id != k_no_stream_id, "Internal error: cuStreamGetId returned k_no_stream_id");
  return id;
}
64+
65+
/**
66+
* @brief A class to store a CUDA stream along with metadata
4967
*
5068
* It contains
5169
* - the stream itself,
52-
* - a unique id (proper to CUDASTF, and only valid for streams in our pool, or equal to -1),
53-
* - the pool associated to the unique ID, when valid
54-
* - the device index in which the stream is
70+
* - the stream's unique ID from the CUDA driver (cuStreamGetId), or k_no_stream_id if no stream,
71+
* - the device index in which the stream resides
5572
*/
5673
struct decorated_stream
5774
{
58-
decorated_stream(cudaStream_t stream = nullptr, ::std::ptrdiff_t id = -1, int dev_id = -1)
75+
decorated_stream() = default;
76+
77+
decorated_stream(cudaStream_t stream, unsigned long long id, int dev_id = -1)
5978
: stream(stream)
6079
, id(id)
6180
, dev_id(dev_id)
6281
{}
6382

83+
/** Construct from stream only; id is from cuStreamGetId, dev_id is -1 (filled lazily when needed). */
84+
explicit decorated_stream(cudaStream_t stream)
85+
: stream(stream)
86+
, id(get_stream_id(stream))
87+
, dev_id(-1)
88+
{}
89+
6490
cudaStream_t stream = nullptr;
65-
// Unique ID (-1 if this is not part of our pool)
66-
::std::ptrdiff_t id = -1;
91+
// Unique ID from cuStreamGetId (k_no_stream_id if no stream)
92+
unsigned long long id = k_no_stream_id;
6793
// Device in which this stream resides
6894
int dev_id = -1;
6995
};
@@ -90,10 +116,10 @@ struct stream_pool
90116
* @brief stream_pool constructor taking a number of slots.
91117
*
92118
* Streams are created lazily only via next(place), which activates the place and calls place.create_stream().
93-
* Slot dev_id is set from the created stream; the pool does not store a device id.
119+
* Slot dev_id and id are set when the stream is created in next().
94120
*/
95121
explicit stream_pool(size_t n)
96-
: payload(n, decorated_stream(nullptr, -1, -1))
122+
: payload(n, decorated_stream(nullptr, k_no_stream_id, -1))
97123
{}
98124

99125
stream_pool(stream_pool&& rhs)
@@ -156,36 +182,6 @@ public:
156182
static constexpr size_t data_pool_size = 4;
157183

158184
private:
159-
/**
160-
* @brief A helper class to maintain a set of available IDs, and attributes IDs
161-
*/
162-
class id_pool
163-
{
164-
public:
165-
~id_pool()
166-
{
167-
assert(released.load() == current.load());
168-
}
169-
170-
::std::ptrdiff_t get_unique_id(size_t cnt = 1)
171-
{
172-
// Use fetch_add to atomically increment current and return the previous value
173-
return current.fetch_add(cnt);
174-
}
175-
176-
void release_unique_id(::std::ptrdiff_t /* id */, size_t cnt = 1)
177-
{
178-
// Use fetch_add to atomically increment released
179-
released.fetch_add(cnt);
180-
}
181-
182-
private:
183-
// next available ID
184-
::std::atomic<::std::ptrdiff_t> current{0};
185-
// Number of IDs released, for bookkeeping
186-
::std::atomic<::std::ptrdiff_t> released{0};
187-
};
188-
189185
/**
190186
* @brief This class implements a matrix to keep track of the previous
191187
* synchronization that occurred between each pair of streams in our pools.
@@ -195,7 +191,7 @@ private:
195191
* ID) is implied by the previous synchronization, so it can be skipped thanks
196192
* to stream-ordering of operations.
197193
*
198-
* This is implemented as a hash table where keys are pairs of IDs.
194+
* Keys are pairs of stream IDs from cuStreamGetId.
199195
*/
200196
class last_event_per_stream
201197
{
@@ -204,10 +200,10 @@ private:
204200
// located on stream "from" to stream "dst" (stream dst waits for the
205201
// event)
206202
// Returned value : boolean indicating if we can skip the synchronization
207-
bool validate_sync_and_update(::std::ptrdiff_t dst, ::std::ptrdiff_t src, int event_id)
203+
bool validate_sync_and_update(unsigned long long dst, unsigned long long src, int event_id)
208204
{
209-
// If either of the streams is not from the pool, do not skip
210-
if (dst == -1 || src == -1)
205+
// If either of the streams has no valid id, do not skip
206+
if (dst == k_no_stream_id || src == k_no_stream_id)
211207
{
212208
return false;
213209
}
@@ -232,10 +228,10 @@ private:
232228
}
233229

234230
private:
235-
// For each pair of unique IDs, we keep the last event id
236-
::std::unordered_map<::std::pair<::std::ptrdiff_t, ::std::ptrdiff_t>,
231+
// For each pair of stream IDs (from cuStreamGetId), we keep the last event id
232+
::std::unordered_map<::std::pair<unsigned long long, unsigned long long>,
237233
int,
238-
cuda::experimental::stf::hash<::std::pair<::std::ptrdiff_t, ::std::ptrdiff_t>>>
234+
cuda::experimental::stf::hash<::std::pair<unsigned long long, unsigned long long>>>
239235
interactions;
240236

241237
::std::mutex mtx;
@@ -295,7 +291,7 @@ private:
295291
for (auto i : each(n))
296292
{
297293
::std::ignore = i;
298-
new_payload.emplace_back(nullptr, ids.get_unique_id(), dev_id);
294+
new_payload.emplace_back(nullptr, k_no_stream_id, dev_id);
299295
}
300296

301297
::std::lock_guard<::std::mutex> locker(p.mtx);
@@ -312,7 +308,6 @@ private:
312308
// Clean up outside the critical section
313309
for (auto& e : goner)
314310
{
315-
ids.release_unique_id(e.id);
316311
if (e.stream)
317312
{
318313
cuda_safe_call(cudaStreamDestroy(e.stream));
@@ -321,9 +316,6 @@ private:
321316
}
322317

323318
public:
324-
// These are constructed and destroyed in reversed order
325-
id_pool ids;
326-
327319
// This memorizes the last event used to synchronize a pair of streams
328320
last_event_per_stream cached_syncs;
329321

@@ -359,19 +351,7 @@ public:
359351
return pimpl->get_device_stream_pool(dev_id, for_computation);
360352
}
361353

362-
::std::ptrdiff_t get_unique_id(size_t cnt = 1)
363-
{
364-
assert(pimpl);
365-
return pimpl->ids.get_unique_id(cnt);
366-
}
367-
368-
void release_unique_id(::std::ptrdiff_t id, size_t cnt = 1)
369-
{
370-
assert(pimpl);
371-
return pimpl->ids.release_unique_id(id, cnt);
372-
}
373-
374-
bool validate_sync_and_update(::std::ptrdiff_t dst, ::std::ptrdiff_t src, int event_id)
354+
bool validate_sync_and_update(unsigned long long dst, unsigned long long src, int event_id)
375355
{
376356
assert(pimpl);
377357
return pimpl->cached_syncs.validate_sync_and_update(dst, src, event_id);
@@ -446,54 +426,6 @@ public:
446426
}
447427
};
448428

449-
//! @brief Registers a user-provided CUDA stream with asynchronous resources
450-
//!
451-
//! @details This optimization records a CUDA stream in the provided asynchronous resources handle,
452-
//! creating a decorated_stream object that encapsulates:
453-
//! - The original stream handle
454-
//! - A unique identifier for stream tracking
455-
//! - The associated device ID
456-
//!
457-
//! @param[in,out] async_resources Handle to asynchronous resources manager
458-
//! @param[in] user_stream Raw CUDA stream to register. Must be a valid stream.
459-
//!
460-
//! @return decorated_stream Object containing:
461-
//! - Original stream handle
462-
//! - Unique ID from async_resources
463-
//! - Device ID associated with the stream
464-
//!
465-
//! @pre `user_stream` must be a valid CUDA stream created with `cudaStreamCreate` or equivalent
466-
//! @note This registration is an optimization to avoid repeated stream metadata lookups
467-
//! in performance-critical code paths
468-
inline decorated_stream register_stream(async_resources_handle& async_resources, cudaStream_t user_stream)
469-
{
470-
// Get a unique ID
471-
const auto id = async_resources.get_unique_id();
472-
const int dev_id = get_device_from_stream(user_stream);
473-
474-
return decorated_stream(user_stream, id, dev_id);
475-
}
476-
477-
//! @brief Unregisters a decorated CUDA stream from asynchronous resources
478-
//!
479-
//! @details Performs cleanup operations to release resources associated with a previously
480-
//! registered stream. This includes:
481-
//! - Releasing the unique ID back to the resource manager
482-
//! - Invalidating the decorated stream's internal ID
483-
//!
484-
//! @param[in,out] async_resources Handle to asynchronous resources manager
485-
//! @param[in,out] dstream Decorated stream to unregister. Its `id` will be set to -1.
486-
//!
487-
//! @pre `dstream.id` must be valid (≥ 0) before calling this function
488-
//! @post `dstream.id == -1` and associated resources are released
489-
//! @note Should be paired with register_stream() for proper resource management
490-
inline void unregister_stream(async_resources_handle& async_resources, decorated_stream& dstream)
491-
{
492-
async_resources.release_unique_id(dstream.id);
493-
// reset the decorated stream
494-
dstream.id = -1;
495-
}
496-
497429
#ifdef UNITTESTED_FILE
498430
/*
499431
* This test ensures that the async_resources_handle type is default

cudax/include/cuda/experimental/__stf/places/exec/cuda_stream.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public:
114114
inline exec_place_cuda_stream exec_place::cuda_stream(cudaStream_t stream)
115115
{
116116
int devid = get_device_from_stream(stream);
117-
return exec_place_cuda_stream(decorated_stream(stream, -1, devid));
117+
return exec_place_cuda_stream(decorated_stream(stream, get_stream_id(stream), devid));
118118
}
119119

120120
inline exec_place_cuda_stream exec_place::cuda_stream(const decorated_stream& dstream)

cudax/include/cuda/experimental/__stf/places/places.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,7 @@ inline decorated_stream stream_pool::next(const exec_place& place)
12181218
{
12191219
exec_place_guard guard(place);
12201220
result.stream = place.create_stream();
1221+
result.id = get_stream_id(result.stream);
12211222
result.dev_id = get_device_from_stream(result.stream);
12221223
}
12231224

cudax/include/cuda/experimental/__stf/stream/internal/event_types.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ public:
255255
return dstream;
256256
}
257257

258-
::std::ptrdiff_t get_stream_id() const
258+
unsigned long long get_stream_id() const
259259
{
260260
return dstream.id;
261261
}
@@ -397,7 +397,7 @@ private:
397397
for (const auto& e : prereq_in)
398398
{
399399
cudaStream_t stream;
400-
::std::ptrdiff_t stream_id = -1;
400+
unsigned long long stream_id = 0;
401401
auto se = reserved::handle<stream_and_event, reserved::handle_flags::non_null>(e, reserved::use_static_cast);
402402
stream = se->get_stream();
403403
stream_id = se->get_stream_id();
@@ -415,7 +415,7 @@ private:
415415
if (stream_dev == devid)
416416
{
417417
// fprintf(stderr, "Found matching device %d with stream %p\n", devid, stream);
418-
return decorated_stream(stream, stream_id, devid);
418+
return decorated_stream(stream, stream_id, static_cast<int>(stream_dev));
419419
}
420420
}
421421

cudax/include/cuda/experimental/__stf/stream/stream_task.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public:
156156
auto se = reserved::handle<stream_and_event>(e, reserved::use_static_cast);
157157
decorated_stream candidate = se->get_decorated_stream();
158158

159-
if (candidate.id != -1)
159+
if (candidate.id != k_no_stream_id)
160160
{
161161
for (const decorated_stream& pool_s : pool)
162162
{

cudax/test/stf/places/cuda_stream_place.cu

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,30 +53,16 @@ int main()
5353
auto lX = ctx.logical_data(X);
5454
auto lY = ctx.logical_data(Y);
5555

56-
/* Compute Y = Y + alpha X */
56+
/* Compute Y = Y + alpha X on the user stream */
5757
auto where = exec_place::cuda_stream(stream);
5858

59-
for (size_t iter = 0; iter < 10; iter++)
59+
for (size_t iter = 0; iter < 20; iter++)
6060
{
6161
ctx.parallel_for(where, lX.shape(), lX.read(), lY.rw())->*[alpha] __device__(size_t i, auto x, auto y) {
6262
y(i) += alpha * x(i);
6363
};
6464
}
6565

66-
/* Associate the CUDA stream with a unique internal ID to speed up synchronizations */
67-
auto rstream = register_stream(ctx.async_resources(), stream);
68-
auto where2 = exec_place::cuda_stream(rstream);
69-
70-
for (size_t iter = 0; iter < 10; iter++)
71-
{
72-
ctx.parallel_for(where2, lX.shape(), lX.read(), lY.rw())->*[alpha] __device__(size_t i, auto x, auto y) {
73-
y(i) += alpha * x(i);
74-
};
75-
}
76-
77-
// Remove the association
78-
unregister_stream(ctx.async_resources(), rstream);
79-
8066
ctx.finalize();
8167

8268
for (size_t i = 0; i < N; i++)

0 commit comments

Comments
 (0)