@@ -107,8 +107,6 @@ struct stream_queue_t {
107107 // get_next_compute/transfer_stream() functions return streams from
108108 // appropriate pools in round-robin fashion
109109 native_type getNextComputeStream (uint32_t *StreamToken = nullptr ) {
110- if (getThreadLocalStream () != native_type{0 })
111- return getThreadLocalStream ();
112110 uint32_t StreamI;
113111 uint32_t Token;
114112 while (true ) {
@@ -150,8 +148,6 @@ struct stream_queue_t {
150148 const ur_event_handle_t *EventWaitList,
151149 ur_stream_guard &Guard,
152150 uint32_t *StreamToken = nullptr ) {
153- if (getThreadLocalStream () != native_type{0 })
154- return getThreadLocalStream ();
155151 for (uint32_t i = 0 ; i < NumEventsInWaitList; i++) {
156152 uint32_t Token = getEventComputeStreamToken (EventWaitList[i]);
157153 if (getEventQueue (EventWaitList[i]) == this && canReuseStream (Token)) {
@@ -175,15 +171,7 @@ struct stream_queue_t {
175171 return getNextComputeStream (StreamToken);
176172 }
177173
178- // Thread local stream will be used if ScopedStream is active
179- static native_type &getThreadLocalStream () {
180- static thread_local native_type stream{0 };
181- return stream;
182- }
183-
184174 native_type getNextTransferStream () {
185- if (getThreadLocalStream () != native_type{0 })
186- return getThreadLocalStream ();
187175 if (TransferStreams.empty ()) { // for example in in-order queue
188176 return getNextComputeStream ();
189177 }
@@ -354,4 +342,34 @@ struct stream_queue_t {
354342 uint32_t getNextEventId () noexcept { return ++EventCount; } // pre-increments EventCount and returns the new value
355343
356344 bool backendHasOwnership () const noexcept { return HasOwnership; } // read-only accessor for the HasOwnership member
345+
346+ // Interop handling: returns the thread-local stream when it is non-zero
347+ // (interop_guard sets it for native commands), otherwise the next compute
348+ // stream. Native commands need a single in-order stream to work.
349+ native_type getInteropStream () {
350+ if (getThreadLocalStream () != native_type{0 }) // zero means no thread-local stream is currently set
351+ return getThreadLocalStream ();
352+
353+ return getNextComputeStream (); // regular interop: fall back to the next compute stream
354+ }
355+
356+ static native_type &getThreadLocalStream () { // static member: the slot is per thread, not per queue object
357+ static thread_local native_type stream{0 }; // starts as zero on each thread, which callers treat as "unset"
358+ return stream; // returned by reference so callers can both read and assign the slot
359+ }
360+
361+ class interop_guard { // scoped helper: sets the thread-local stream in the constructor and clears it in the destructor
362+ stream_queue_t *q; // queue given to the constructor; stored as a raw pointer and never deleted here
363+
364+ public:
365+ interop_guard (stream_queue_t *q, uint32_t NumEventsInWaitList,
366+ const ur_event_handle_t *EventWaitList)
367+ : q{q} {
368+ ur_stream_guard Guard; // local, so it is destroyed when this constructor returns -- NOTE(review): confirm nothing it holds must outlive the constructor
369+ q->getThreadLocalStream () = // publish the stream chosen for this wait list into the calling thread's slot
370+ q->getNextComputeStream (NumEventsInWaitList, EventWaitList, Guard);
371+ }
372+ native_type getStream () { return q->getThreadLocalStream (); } // returns the current value of the thread-local slot
373+ ~interop_guard () { q->getThreadLocalStream () = native_type{0 }; } // resets the slot to zero rather than restoring a prior value -- NOTE(review): nested guards on one thread would clear the outer one; copy operations are not deleted, confirm guards are never copied
374+ };
357375};
0 commit comments