Skip to content

Commit 945976c

Browse files
authored
[UR][Offload] Event waiting (#19594)
Implement urEnqueueEventsWait[WithBarrier] and respect the wait list of the enqueue functions.
1 parent 74dbc1e commit 945976c

File tree

3 files changed

+161
-54
lines changed

3 files changed

+161
-54
lines changed

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 139 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,130 @@
1919
#include "queue.hpp"
2020
#include "ur2offload.hpp"
2121

22+
namespace {
23+
// Makes `Queue` wait on the Offload events wrapped by the first `NumEvents`
// entries of `UrEvents`.
//
// Returns OL_SUCCESS without touching the queue when `NumEvents` is zero;
// otherwise returns whatever olWaitEvents reports.
ol_result_t waitOnEvents(ol_queue_handle_t Queue,
                         const ur_event_handle_t *UrEvents, size_t NumEvents) {
  if (NumEvents == 0) {
    return OL_SUCCESS;
  }

  // olWaitEvents takes raw Offload handles, so unwrap each UR event first.
  std::vector<ol_event_handle_t> Handles(NumEvents);
  for (size_t Idx = 0; Idx != NumEvents; ++Idx) {
    Handles[Idx] = UrEvents[Idx]->OffloadEvent;
  }

  return olWaitEvents(Queue, Handles.data(), NumEvents);
}
36+
37+
// Optionally produces an output event for a command of kind `Type`.
//
// When `UrEvent` is null the caller did not ask for an event and nothing is
// created. Otherwise a new UR event bound to `UrQueue` is allocated, an
// Offload event is created on `OlQueue` and stored in it, and the result is
// written to `*UrEvent`. If olCreateEvent fails, the UR event is freed,
// `*UrEvent` is left untouched and the failing result is returned.
ol_result_t makeEvent(ur_command_t Type, ol_queue_handle_t OlQueue,
                      ur_queue_handle_t UrQueue, ur_event_handle_t *UrEvent) {
  if (!UrEvent) {
    return OL_SUCCESS;
  }

  auto *Created = new ur_event_handle_t_(Type, UrQueue);
  auto Status = olCreateEvent(OlQueue, &Created->OffloadEvent);
  if (Status) {
    // Do not hand a half-initialised event back to the caller.
    delete Created;
    return Status;
  }

  *UrEvent = Created;
  return OL_SUCCESS;
}
49+
50+
template <bool Barrier>
51+
ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
52+
const ur_event_handle_t *phEventWaitList,
53+
ur_event_handle_t *phEvent) {
54+
std::lock_guard<std::mutex> Lock(hQueue->OooMutex);
55+
constexpr ur_command_t TYPE =
56+
Barrier ? UR_COMMAND_EVENTS_WAIT_WITH_BARRIER : UR_COMMAND_EVENTS_WAIT;
57+
ol_queue_handle_t TargetQueue;
58+
if (!numEventsInWaitList && hQueue->isInOrder()) {
59+
// In order queue so all work is done in submission order, so it's a
60+
// no-op
61+
if (phEvent) {
62+
OL_RETURN_ON_ERR(hQueue->nextQueueNoLock(TargetQueue));
63+
OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
64+
}
65+
return UR_RESULT_SUCCESS;
66+
}
67+
OL_RETURN_ON_ERR(hQueue->nextQueueNoLock(TargetQueue));
68+
69+
if (!numEventsInWaitList) {
70+
// "If the event list is empty, it waits for all previously enqueued
71+
// commands to complete."
72+
73+
// Create events on each active queue for an arbitrary thread to block on
74+
// TODO: Can we efficiently check if each thread is "finished" rather than
75+
// creating an event?
76+
std::vector<ol_event_handle_t> OffloadHandles{};
77+
for (auto *Q : hQueue->OffloadQueues) {
78+
if (Q == nullptr) {
79+
break;
80+
}
81+
if (Q == TargetQueue) {
82+
continue;
83+
}
84+
OL_RETURN_ON_ERR(olCreateEvent(Q, &OffloadHandles.emplace_back()));
85+
}
86+
OL_RETURN_ON_ERR(olWaitEvents(TargetQueue, OffloadHandles.data(),
87+
OffloadHandles.size()));
88+
} else {
89+
OL_RETURN_ON_ERR(
90+
waitOnEvents(TargetQueue, phEventWaitList, numEventsInWaitList));
91+
}
92+
93+
OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
94+
95+
if constexpr (Barrier) {
96+
ol_event_handle_t BarrierEvent;
97+
if (phEvent) {
98+
BarrierEvent = (*phEvent)->OffloadEvent;
99+
} else {
100+
OL_RETURN_ON_ERR(olCreateEvent(TargetQueue, &BarrierEvent));
101+
}
102+
103+
// Ensure any newly created work waits on this barrier
104+
if (hQueue->Barrier) {
105+
OL_RETURN_ON_ERR(olDestroyEvent(hQueue->Barrier));
106+
}
107+
hQueue->Barrier = BarrierEvent;
108+
109+
// Block all existing threads on the barrier
110+
for (auto *Q : hQueue->OffloadQueues) {
111+
if (Q == nullptr) {
112+
break;
113+
}
114+
if (Q == TargetQueue) {
115+
continue;
116+
}
117+
OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent, 1));
118+
}
119+
}
120+
121+
return UR_RESULT_SUCCESS;
122+
}
123+
} // namespace
124+
125+
// UR entry point for an events-wait command without barrier semantics.
// All of the work is done by doWait with its Barrier parameter disabled.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
    ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
  constexpr bool WithBarrier = false;
  return doWait<WithBarrier>(hQueue, numEventsInWaitList, phEventWaitList,
                             phEvent);
}
130+
131+
// UR entry point for an events-wait command with barrier semantics.
// All of the work is done by doWait with its Barrier parameter enabled.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
    ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
  constexpr bool WithBarrier = true;
  return doWait<WithBarrier>(hQueue, numEventsInWaitList, phEventWaitList,
                             phEvent);
}
136+
22137
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
23138
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
24139
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
25140
const size_t *pLocalWorkSize, uint32_t, const ur_kernel_launch_property_t *,
26141
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
27142
ur_event_handle_t *phEvent) {
28-
// Ignore wait list for now
29-
(void)numEventsInWaitList;
30-
(void)phEventWaitList;
31-
//
143+
ol_queue_handle_t Queue;
144+
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
145+
OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
32146

33147
(void)pGlobalWorkOffset;
34148

@@ -67,20 +181,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
67181
LaunchArgs.GroupSize.z = GroupSize[2];
68182
LaunchArgs.DynSharedMemory = 0;
69183

70-
ol_queue_handle_t Queue;
71-
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
72184
OL_RETURN_ON_ERR(olLaunchKernel(
73185
Queue, hQueue->OffloadDevice, hKernel->OffloadKernel,
74186
hKernel->Args.getStorage(), hKernel->Args.getStorageSize(), &LaunchArgs));
75187

76-
if (phEvent) {
77-
auto *Event = new ur_event_handle_t_(UR_COMMAND_KERNEL_LAUNCH, hQueue);
78-
if (auto Res = olCreateEvent(Queue, &Event->OffloadEvent)) {
79-
delete Event;
80-
return offloadResultToUR(Res);
81-
};
82-
*phEvent = Event;
83-
}
188+
OL_RETURN_ON_ERR(makeEvent(UR_COMMAND_KERNEL_LAUNCH, Queue, hQueue, phEvent));
84189
return UR_RESULT_SUCCESS;
85190
}
86191

@@ -103,10 +208,9 @@ ur_result_t doMemcpy(ur_command_t Command, ur_queue_handle_t hQueue,
103208
size_t size, bool blocking, uint32_t numEventsInWaitList,
104209
const ur_event_handle_t *phEventWaitList,
105210
ur_event_handle_t *phEvent) {
106-
// Ignore wait list for now
107-
(void)numEventsInWaitList;
108-
(void)phEventWaitList;
109-
//
211+
ol_queue_handle_t Queue;
212+
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
213+
OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
110214

111215
if (blocking) {
112216
OL_RETURN_ON_ERR(
@@ -117,8 +221,6 @@ ur_result_t doMemcpy(ur_command_t Command, ur_queue_handle_t hQueue,
117221
return UR_RESULT_SUCCESS;
118222
}
119223

120-
ol_queue_handle_t Queue;
121-
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
122224
OL_RETURN_ON_ERR(
123225
olMemcpy(Queue, DestPtr, DestDevice, SrcPtr, SrcDevice, size));
124226
if (phEvent) {
@@ -192,17 +294,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
192294
numEventsInWaitList, phEventWaitList, phEvent);
193295
}
194296

195-
ur_result_t enqueueNoOp(ur_command_t Type, ur_queue_handle_t hQueue,
196-
ur_event_handle_t *phEvent) {
197-
// This path is a no-op, but we can't output a real event because
198-
// Offload doesn't currently support creating arbitrary events, and we
199-
// don't know the last real event in the queue. Instead we just have to
200-
// wait on the whole queue and then return an empty (implicitly
201-
// finished) event.
202-
*phEvent = ur_event_handle_t_::createEmptyEvent(Type, hQueue);
203-
return urQueueFinish(hQueue);
204-
}
205-
206297
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
207298
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingMap,
208299
ur_map_flags_t mapFlags, size_t offset, size_t size,
@@ -226,15 +317,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
226317
Result = urEnqueueMemBufferRead(hQueue, hBuffer, blockingMap, offset, size,
227318
MapPtr, numEventsInWaitList,
228319
phEventWaitList, phEvent);
229-
} else {
230-
if (IsPinned) {
231-
// TODO: Ignore the event waits list for now. When urEnqueueEventsWait is
232-
// implemented we can call it on the wait list.
233-
}
234-
235-
if (phEvent) {
236-
enqueueNoOp(UR_COMMAND_MEM_BUFFER_MAP, hQueue, phEvent);
320+
} else if (numEventsInWaitList || phEvent) {
321+
ol_queue_handle_t Queue;
322+
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
323+
if ((!hQueue->isInOrder() && phEvent) || hQueue->isInOrder()) {
324+
// Out-of-order queues running no-op work only have side effects if there
325+
// is an output event
326+
waitOnEvents(Queue, phEventWaitList, numEventsInWaitList);
237327
}
328+
OL_RETURN_ON_ERR(
329+
makeEvent(UR_COMMAND_MEM_BUFFER_MAP, Queue, hQueue, phEvent));
238330
}
239331
*ppRetMap = MapPtr;
240332

@@ -260,15 +352,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
260352
Result = urEnqueueMemBufferWrite(
261353
hQueue, hMem, true, Map->MapOffset, Map->MapSize, pMappedPtr,
262354
numEventsInWaitList, phEventWaitList, phEvent);
263-
} else {
264-
if (IsPinned) {
265-
// TODO: Ignore the event waits list for now. When urEnqueueEventsWait is
266-
// implemented we can call it on the wait list.
267-
}
268-
269-
if (phEvent) {
270-
enqueueNoOp(UR_COMMAND_MEM_UNMAP, hQueue, phEvent);
355+
} else if (numEventsInWaitList || phEvent) {
356+
ol_queue_handle_t Queue;
357+
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
358+
if ((!hQueue->isInOrder() && phEvent) || hQueue->isInOrder()) {
359+
// Out-of-order queues running no-op work only have side effects if there
360+
// is an output event
361+
waitOnEvents(Queue, phEventWaitList, numEventsInWaitList);
271362
}
363+
OL_RETURN_ON_ERR(makeEvent(UR_COMMAND_MEM_UNMAP, Queue, hQueue, phEvent));
272364
}
273365
BufferImpl.unmap(pMappedPtr);
274366

unified-runtime/source/adapters/offload/queue.hpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ struct ur_queue_handle_t_ : RefCounted {
2323
: OffloadQueues((Flags & UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE)
2424
? 1
2525
: OOO_QUEUE_POOL_SIZE),
26-
QueueOffset(0), OffloadDevice(Device), UrContext(UrContext),
27-
Flags(Flags) {}
26+
QueueOffset(0), Barrier(nullptr), OffloadDevice(Device),
27+
UrContext(UrContext), Flags(Flags) {}
2828

2929
// In-order queues only have one element here, while out of order queues have
3030
// a bank of queues to use. We rotate through them round robin instead of
@@ -35,22 +35,37 @@ struct ur_queue_handle_t_ : RefCounted {
3535
// `stream_queue_t`. In the future, if we want more performance or it
3636
// simplifies the implementation of a feature, we can consider using it.
3737
std::vector<ol_queue_handle_t> OffloadQueues;
38+
// Mutex guarding the offset and barrier for out of order queues
39+
std::mutex OooMutex;
3840
size_t QueueOffset;
41+
ol_event_handle_t Barrier;
3942
ol_device_handle_t OffloadDevice;
4043
ur_context_handle_t UrContext;
4144
ur_queue_flags_t Flags;
4245

43-
ol_result_t nextQueue(ol_queue_handle_t &Handle) {
44-
auto &Slot = OffloadQueues[QueueOffset++];
45-
QueueOffset %= OffloadQueues.size();
46+
// Reports whether this queue is backed by exactly one Offload queue, which
// is how an in-order queue is represented here.
// NOTE(review): the constructor's initializer appears to size the pool to 1
// when UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE is set, which looks
// inverted relative to this predicate — confirm.
bool isInOrder() const {
  return 1 == OffloadQueues.size();
}
47+
48+
// Selects the next Offload queue in round-robin order and returns it through
// `Handle`. No locking is done here; the callers visible in this adapter hold
// OooMutex around the call.
//
// A slot that has not been used before is created on demand, and if a
// barrier event is currently recorded the new queue is made to wait on it
// first. Returns nullptr on success, otherwise the failing Offload result
// (in which case `Handle` is not written).
ol_result_t nextQueueNoLock(ol_queue_handle_t &Handle) {
  // Advance the rotation regardless of whether the slot is usable yet.
  const size_t Index = QueueOffset % OffloadQueues.size();
  ++QueueOffset;
  ol_queue_handle_t &Entry = OffloadQueues[Index];

  if (Entry) {
    Handle = Entry;
    return nullptr;
  }

  if (auto CreateRes = olCreateQueue(OffloadDevice, &Entry)) {
    return CreateRes;
  }

  // A queue created after a barrier must not run ahead of it.
  auto Pending = Barrier;
  if (Pending) {
    if (auto WaitRes = olWaitEvents(Entry, &Pending, 1)) {
      return WaitRes;
    }
  }

  Handle = Entry;
  return nullptr;
}
66+
67+
// Locking wrapper around nextQueueNoLock: takes OooMutex for the duration of
// the selection and forwards the result unchanged.
ol_result_t nextQueue(ol_queue_handle_t &Handle) {
  std::lock_guard<std::mutex> Guard(OooMutex);
  ol_result_t Outcome = nextQueueNoLock(Handle);
  return Outcome;
}
5671
};

unified-runtime/source/adapters/offload/ur_interface_loader.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
170170
}
171171
pDdiTable->pfnDeviceGlobalVariableRead = urEnqueueDeviceGlobalVariableRead;
172172
pDdiTable->pfnDeviceGlobalVariableWrite = urEnqueueDeviceGlobalVariableWrite;
173-
pDdiTable->pfnEventsWait = nullptr;
174-
pDdiTable->pfnEventsWaitWithBarrier = nullptr;
173+
pDdiTable->pfnEventsWait = urEnqueueEventsWait;
174+
pDdiTable->pfnEventsWaitWithBarrier = urEnqueueEventsWaitWithBarrier;
175175
pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
176176
pDdiTable->pfnMemBufferCopy = nullptr;
177177
pDdiTable->pfnMemBufferCopyRect = nullptr;

0 commit comments

Comments
 (0)