3737#include < mutex>
3838#include < unordered_map>
3939
40+ #include < cuda.h>
41+
4042namespace cuda ::experimental::stf
4143{
4244class green_context_helper ;
4345
4446// Needed to set/get affinity
4547class exec_place ;
4648
/**
 * Marker ID meaning "no stream" / empty slot.
 * It is the all-ones value, treated as distinct from any value returned by cuStreamGetId.
 */
inline constexpr unsigned long long k_no_stream_id = ~0ull;
51+
4752/* *
48- * @brief A class to store a CUDA stream along with a few information to avoid CUDA queries
53+ * @brief Returns the unique stream ID from the CUDA driver (cuStreamGetId).
54+ * @param stream A valid CUDA stream, or nullptr.
55+ * @return The stream's unique ID, or k_no_stream_id if stream is nullptr.
56+ */
57+ inline unsigned long long get_stream_id (cudaStream_t stream)
58+ {
59+ unsigned long long id = 0 ;
60+ cuda_safe_call (cuStreamGetId (reinterpret_cast <CUstream>(stream), &id));
61+ _CCCL_ASSERT (id != k_no_stream_id, " Internal error: cuStreamGetId returned k_no_stream_id" );
62+ return id;
63+ }
64+
65+ /* *
66+ * @brief A class to store a CUDA stream along with metadata
4967 *
5068 * It contains
5169 * - the stream itself,
52- * - a unique id (proper to CUDASTF, and only valid for streams in our pool, or equal to -1),
53- * - the pool associated to the unique ID, when valid
54- * - the device index in which the stream is
70+ * - the stream's unique ID from the CUDA driver (cuStreamGetId), or k_no_stream_id if no stream,
71+ * - the device index in which the stream resides
5572 */
5673struct decorated_stream
5774{
58- decorated_stream (cudaStream_t stream = nullptr , ::std::ptrdiff_t id = -1 , int dev_id = -1 )
75+ decorated_stream () = default ;
76+
77+ decorated_stream (cudaStream_t stream, unsigned long long id, int dev_id = -1 )
5978 : stream(stream)
6079 , id(id)
6180 , dev_id(dev_id)
6281 {}
6382
83+ /* * Construct from stream only; id is from cuStreamGetId, dev_id is -1 (filled lazily when needed). */
84+ explicit decorated_stream (cudaStream_t stream)
85+ : stream(stream)
86+ , id(get_stream_id(stream))
87+ , dev_id(-1 )
88+ {}
89+
6490 cudaStream_t stream = nullptr ;
65- // Unique ID (-1 if this is not part of our pool )
66- ::std:: ptrdiff_t id = - 1 ;
91+ // Unique ID from cuStreamGetId (k_no_stream_id if no stream )
92+ unsigned long long id = k_no_stream_id ;
6793 // Device in which this stream resides
6894 int dev_id = -1 ;
6995};
@@ -90,10 +116,10 @@ struct stream_pool
90116 * @brief stream_pool constructor taking a number of slots.
91117 *
92118 * Streams are created lazily only via next(place), which activates the place and calls place.create_stream().
93- * Slot dev_id is set from the created stream; the pool does not store a device id .
119+ * Slot dev_id and id are set when the stream is created in next() .
94120 */
95121 explicit stream_pool (size_t n)
96- : payload(n, decorated_stream(nullptr , - 1 , -1 ))
122+ : payload(n, decorated_stream(nullptr , k_no_stream_id , -1 ))
97123 {}
98124
99125 stream_pool (stream_pool&& rhs)
@@ -156,36 +182,6 @@ public:
156182 static constexpr size_t data_pool_size = 4 ;
157183
158184private:
159- /* *
160- * @brief A helper class to maintain a set of available IDs, and attributes IDs
161- */
162- class id_pool
163- {
164- public:
165- ~id_pool ()
166- {
167- assert (released.load () == current.load ());
168- }
169-
170- ::std::ptrdiff_t get_unique_id (size_t cnt = 1 )
171- {
172- // Use fetch_add to atomically increment current and return the previous value
173- return current.fetch_add (cnt);
174- }
175-
176- void release_unique_id (::std::ptrdiff_t /* id */ , size_t cnt = 1 )
177- {
178- // Use fetch_add to atomically increment released
179- released.fetch_add (cnt);
180- }
181-
182- private:
183- // next available ID
184- ::std::atomic<::std::ptrdiff_t > current{0 };
185- // Number of IDs released, for bookkeeping
186- ::std::atomic<::std::ptrdiff_t > released{0 };
187- };
188-
189185 /* *
190186 * @brief This class implements a matrix to keep track of the previous
191187 * synchronization that occurred between each pair of streams in our pools.
@@ -195,7 +191,7 @@ private:
195191 * ID) is implied by the previous synchronization, so it can be skipped thanks
196192 * to stream-ordering of operations.
197193 *
198- * This is implemented as a hash table where keys are pairs of IDs.
194+ * Keys are pairs of stream IDs from cuStreamGetId .
199195 */
200196 class last_event_per_stream
201197 {
@@ -204,10 +200,10 @@ private:
204200 // located on stream "from" to stream "dst" (stream dst waits for the
205201 // event)
206202 // Returned value : boolean indicating if we can skip the synchronization
207- bool validate_sync_and_update (::std:: ptrdiff_t dst, ::std:: ptrdiff_t src, int event_id)
203+ bool validate_sync_and_update (unsigned long long dst, unsigned long long src, int event_id)
208204 {
209- // If either of the streams is not from the pool , do not skip
210- if (dst == - 1 || src == - 1 )
205+ // If either of the streams has no valid id , do not skip
206+ if (dst == k_no_stream_id || src == k_no_stream_id )
211207 {
212208 return false ;
213209 }
@@ -232,10 +228,10 @@ private:
232228 }
233229
234230 private:
235- // For each pair of unique IDs, we keep the last event id
236- ::std::unordered_map<::std::pair<::std:: ptrdiff_t , ::std:: ptrdiff_t >,
231+ // For each pair of stream IDs (from cuStreamGetId) , we keep the last event id
232+ ::std::unordered_map<::std::pair<unsigned long long , unsigned long long >,
237233 int ,
238- cuda::experimental::stf::hash<::std::pair<::std:: ptrdiff_t , ::std:: ptrdiff_t >>>
234+ cuda::experimental::stf::hash<::std::pair<unsigned long long , unsigned long long >>>
239235 interactions;
240236
241237 ::std::mutex mtx;
@@ -295,7 +291,7 @@ private:
295291 for (auto i : each (n))
296292 {
297293 ::std::ignore = i;
298- new_payload.emplace_back (nullptr , ids. get_unique_id () , dev_id);
294+ new_payload.emplace_back (nullptr , k_no_stream_id , dev_id);
299295 }
300296
301297 ::std::lock_guard<::std::mutex> locker (p.mtx );
@@ -312,7 +308,6 @@ private:
312308 // Clean up outside the critical section
313309 for (auto & e : goner)
314310 {
315- ids.release_unique_id (e.id );
316311 if (e.stream )
317312 {
318313 cuda_safe_call (cudaStreamDestroy (e.stream ));
@@ -321,9 +316,6 @@ private:
321316 }
322317
323318 public:
324- // These are constructed and destroyed in reversed order
325- id_pool ids;
326-
// Remembers the last event that was used to synchronize each pair of streams
328320 last_event_per_stream cached_syncs;
329321
@@ -359,19 +351,7 @@ public:
359351 return pimpl->get_device_stream_pool (dev_id, for_computation);
360352 }
361353
362- ::std::ptrdiff_t get_unique_id (size_t cnt = 1 )
363- {
364- assert (pimpl);
365- return pimpl->ids .get_unique_id (cnt);
366- }
367-
368- void release_unique_id (::std::ptrdiff_t id, size_t cnt = 1 )
369- {
370- assert (pimpl);
371- return pimpl->ids .release_unique_id (id, cnt);
372- }
373-
374- bool validate_sync_and_update (::std::ptrdiff_t dst, ::std::ptrdiff_t src, int event_id)
354+ bool validate_sync_and_update (unsigned long long dst, unsigned long long src, int event_id)
375355 {
376356 assert (pimpl);
377357 return pimpl->cached_syncs .validate_sync_and_update (dst, src, event_id);
@@ -446,54 +426,6 @@ public:
446426 }
447427};
448428
449- // ! @brief Registers a user-provided CUDA stream with asynchronous resources
450- // !
451- // ! @details This optimization records a CUDA stream in the provided asynchronous resources handle,
452- // ! creating a decorated_stream object that encapsulates:
453- // ! - The original stream handle
454- // ! - A unique identifier for stream tracking
455- // ! - The associated device ID
456- // !
457- // ! @param[in,out] async_resources Handle to asynchronous resources manager
458- // ! @param[in] user_stream Raw CUDA stream to register. Must be a valid stream.
459- // !
460- // ! @return decorated_stream Object containing:
461- // ! - Original stream handle
462- // ! - Unique ID from async_resources
463- // ! - Device ID associated with the stream
464- // !
465- // ! @pre `user_stream` must be a valid CUDA stream created with `cudaStreamCreate` or equivalent
466- // ! @note This registration is an optimization to avoid repeated stream metadata lookups
467- // ! in performance-critical code paths
468- inline decorated_stream register_stream (async_resources_handle& async_resources, cudaStream_t user_stream)
469- {
470- // Get a unique ID
471- const auto id = async_resources.get_unique_id ();
472- const int dev_id = get_device_from_stream (user_stream);
473-
474- return decorated_stream (user_stream, id, dev_id);
475- }
476-
477- // ! @brief Unregisters a decorated CUDA stream from asynchronous resources
478- // !
479- // ! @details Performs cleanup operations to release resources associated with a previously
480- // ! registered stream. This includes:
481- // ! - Releasing the unique ID back to the resource manager
482- // ! - Invalidating the decorated stream's internal ID
483- // !
484- // ! @param[in,out] async_resources Handle to asynchronous resources manager
485- // ! @param[in,out] dstream Decorated stream to unregister. Its `id` will be set to -1.
486- // !
487- // ! @pre `dstream.id` must be valid (≥ 0) before calling this function
488- // ! @post `dstream.id == -1` and associated resources are released
489- // ! @note Should be paired with register_stream() for proper resource management
490- inline void unregister_stream (async_resources_handle& async_resources, decorated_stream& dstream)
491- {
492- async_resources.release_unique_id (dstream.id );
493- // reset the decorated stream
494- dstream.id = -1 ;
495- }
496-
497429#ifdef UNITTESTED_FILE
498430/*
499431 * This test ensures that the async_resources_handle type is default
0 commit comments