diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp
index b1a1edac522b2..6b9a013538bb4 100644
--- a/unified-runtime/source/adapters/offload/enqueue.cpp
+++ b/unified-runtime/source/adapters/offload/enqueue.cpp
@@ -19,16 +19,130 @@
 #include "queue.hpp"
 #include "ur2offload.hpp"
 
+namespace {
+ol_result_t waitOnEvents(ol_queue_handle_t Queue,
+                         const ur_event_handle_t *UrEvents, size_t NumEvents) {
+  if (NumEvents) {
+    std::vector<ol_event_handle_t> OlEvents;
+    OlEvents.reserve(NumEvents);
+    for (size_t I = 0; I < NumEvents; I++) {
+      OlEvents.push_back(UrEvents[I]->OffloadEvent);
+    }
+
+    return olWaitEvents(Queue, OlEvents.data(), NumEvents);
+  }
+  return OL_SUCCESS;
+}
+
+ol_result_t makeEvent(ur_command_t Type, ol_queue_handle_t OlQueue,
+                      ur_queue_handle_t UrQueue, ur_event_handle_t *UrEvent) {
+  if (UrEvent) {
+    auto *Event = new ur_event_handle_t_(Type, UrQueue);
+    if (auto Res = olCreateEvent(OlQueue, &Event->OffloadEvent)) {
+      delete Event;
+      return Res;
+    }
+    *UrEvent = Event;
+  }
+  return OL_SUCCESS;
+}
+
+template <bool Barrier>
+ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
+                   const ur_event_handle_t *phEventWaitList,
+                   ur_event_handle_t *phEvent) {
+  std::lock_guard Lock(hQueue->OooMutex);
+  constexpr ur_command_t TYPE =
+      Barrier ? UR_COMMAND_EVENTS_WAIT_WITH_BARRIER : UR_COMMAND_EVENTS_WAIT;
+
+  ol_queue_handle_t TargetQueue;
+  if (!numEventsInWaitList && hQueue->isInOrder()) {
+    // In-order queue: all work is done in submission order, so this is a
+    // no-op.
+    if (phEvent) {
+      OL_RETURN_ON_ERR(hQueue->nextQueueNoLock(TargetQueue));
+      OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
+    }
+    return UR_RESULT_SUCCESS;
+  }
+
+  OL_RETURN_ON_ERR(hQueue->nextQueueNoLock(TargetQueue));
+
+  if (!numEventsInWaitList) {
+    // "If the event list is empty, it waits for all previously enqueued
+    // commands to complete."
+
+    // Create events on each active queue for an arbitrary thread to block on
+    // TODO: Can we efficiently check if each thread is "finished" rather than
+    // creating an event?
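+    // OffloadQueues slots are created lazily in order, so the first null
+    // entry marks the end of the active queues.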
+    std::vector<ol_event_handle_t> OffloadHandles{};
+    for (auto *Q : hQueue->OffloadQueues) {
+      if (Q == nullptr) {
+        break;
+      }
+      if (Q == TargetQueue) {
+        continue;
+      }
+      OL_RETURN_ON_ERR(olCreateEvent(Q, &OffloadHandles.emplace_back()));
+    }
+    OL_RETURN_ON_ERR(olWaitEvents(TargetQueue, OffloadHandles.data(),
+                                  OffloadHandles.size()));
+  } else {
+    OL_RETURN_ON_ERR(
+        waitOnEvents(TargetQueue, phEventWaitList, numEventsInWaitList));
+  }
+
+  OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
+
+  if constexpr (Barrier) {
+    ol_event_handle_t BarrierEvent;
+    if (phEvent) {
+      BarrierEvent = (*phEvent)->OffloadEvent;
+    } else {
+      OL_RETURN_ON_ERR(olCreateEvent(TargetQueue, &BarrierEvent));
+    }
+
+    // Ensure any newly created work waits on this barrier
+    if (hQueue->Barrier) {
+      OL_RETURN_ON_ERR(olDestroyEvent(hQueue->Barrier));
+    }
+    hQueue->Barrier = BarrierEvent;
+
+    // Block all existing threads on the barrier
+    for (auto *Q : hQueue->OffloadQueues) {
+      if (Q == nullptr) {
+        break;
+      }
+      if (Q == TargetQueue) {
+        continue;
+      }
+      OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent, 1));
+    }
+  }
+
+  return UR_RESULT_SUCCESS;
+}
+} // namespace
+
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
+    ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
+    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+  return doWait<false>(hQueue, numEventsInWaitList, phEventWaitList, phEvent);
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
+    ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
+    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+  return doWait<true>(hQueue, numEventsInWaitList, phEventWaitList, phEvent);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
     const size_t *pLocalWorkSize, uint32_t,
     const ur_kernel_launch_property_t *, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  // Ignore wait list for now
-  (void)numEventsInWaitList;
-  (void)phEventWaitList;
-  //
+  ol_queue_handle_t Queue;
+  OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
+  OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
 
   (void)pGlobalWorkOffset;
 
@@ -67,20 +181,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   LaunchArgs.GroupSize.z = GroupSize[2];
   LaunchArgs.DynSharedMemory = 0;
 
-  ol_queue_handle_t Queue;
-  OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
   OL_RETURN_ON_ERR(olLaunchKernel(
       Queue, hQueue->OffloadDevice, hKernel->OffloadKernel,
       hKernel->Args.getStorage(), hKernel->Args.getStorageSize(), &LaunchArgs));
 
-  if (phEvent) {
-    auto *Event = new ur_event_handle_t_(UR_COMMAND_KERNEL_LAUNCH, hQueue);
-    if (auto Res = olCreateEvent(Queue, &Event->OffloadEvent)) {
-      delete Event;
-      return offloadResultToUR(Res);
-    };
-    *phEvent = Event;
-  }
+  OL_RETURN_ON_ERR(makeEvent(UR_COMMAND_KERNEL_LAUNCH, Queue, hQueue, phEvent));
 
   return UR_RESULT_SUCCESS;
 }
@@ -103,10 +208,9 @@ ur_result_t doMemcpy(ur_command_t Command, ur_queue_handle_t hQueue,
                      size_t size, bool blocking, uint32_t numEventsInWaitList,
                      const ur_event_handle_t *phEventWaitList,
                      ur_event_handle_t *phEvent) {
-  // Ignore wait list for now
-  (void)numEventsInWaitList;
-  (void)phEventWaitList;
-  //
+  ol_queue_handle_t Queue;
+  OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
+  OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
 
   if (blocking) {
     OL_RETURN_ON_ERR(
@@ -117,8 +221,6 @@ ur_result_t doMemcpy(ur_command_t Command, ur_queue_handle_t hQueue,
     return UR_RESULT_SUCCESS;
   }
 
-  ol_queue_handle_t Queue;
-  OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
   OL_RETURN_ON_ERR(
       olMemcpy(Queue, DestPtr, DestDevice, SrcPtr, SrcDevice, size));
   if (phEvent) {
@@ -192,17 +294,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
                    numEventsInWaitList, phEventWaitList, phEvent);
 }
 
-ur_result_t enqueueNoOp(ur_command_t Type, ur_queue_handle_t hQueue,
-                        ur_event_handle_t *phEvent) {
-  // This path is a no-op, but we can't output a real event because
-  // Offload doesn't currently support creating arbitrary events, and we
-  // don't know the last real event in the queue. Instead we just have to
-  // wait on the whole queue and then return an empty (implicitly
-  // finished) event.
-  *phEvent = ur_event_handle_t_::createEmptyEvent(Type, hQueue);
-  return urQueueFinish(hQueue);
-}
-
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
     ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingMap,
     ur_map_flags_t mapFlags, size_t offset, size_t size,
@@ -226,15 +317,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
     Result = urEnqueueMemBufferRead(hQueue, hBuffer, blockingMap, offset, size,
                                     MapPtr, numEventsInWaitList,
                                     phEventWaitList, phEvent);
-  } else {
-    if (IsPinned) {
-      // TODO: Ignore the event waits list for now. When urEnqueueEventsWait is
-      // implemented we can call it on the wait list.
-    }
-
-    if (phEvent) {
-      enqueueNoOp(UR_COMMAND_MEM_BUFFER_MAP, hQueue, phEvent);
+  } else if (numEventsInWaitList || phEvent) {
+    ol_queue_handle_t Queue;
+    OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
+    if ((!hQueue->isInOrder() && phEvent) || hQueue->isInOrder()) {
+      // On an out-of-order queue this no-op work only has side effects when
+      // an output event is requested
+      OL_RETURN_ON_ERR(
+          waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
     }
+    OL_RETURN_ON_ERR(
+        makeEvent(UR_COMMAND_MEM_BUFFER_MAP, Queue, hQueue, phEvent));
   }
 
   *ppRetMap = MapPtr;
@@ -260,15 +352,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
     Result = urEnqueueMemBufferWrite(
         hQueue, hMem, true, Map->MapOffset, Map->MapSize, pMappedPtr,
         numEventsInWaitList, phEventWaitList, phEvent);
-  } else {
-    if (IsPinned) {
-      // TODO: Ignore the event waits list for now. When urEnqueueEventsWait is
-      // implemented we can call it on the wait list.
-    }
-
-    if (phEvent) {
-      enqueueNoOp(UR_COMMAND_MEM_UNMAP, hQueue, phEvent);
+  } else if (numEventsInWaitList || phEvent) {
+    ol_queue_handle_t Queue;
+    OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
+    if ((!hQueue->isInOrder() && phEvent) || hQueue->isInOrder()) {
+      // On an out-of-order queue this no-op work only has side effects when
+      // an output event is requested
+      OL_RETURN_ON_ERR(
+          waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
     }
+    OL_RETURN_ON_ERR(makeEvent(UR_COMMAND_MEM_UNMAP, Queue, hQueue, phEvent));
   }
 
   BufferImpl.unmap(pMappedPtr);
diff --git a/unified-runtime/source/adapters/offload/queue.hpp b/unified-runtime/source/adapters/offload/queue.hpp
index 25585db273763..8f887a9c3be01 100644
--- a/unified-runtime/source/adapters/offload/queue.hpp
+++ b/unified-runtime/source/adapters/offload/queue.hpp
@@ -23,8 +23,8 @@ struct ur_queue_handle_t_ : RefCounted {
       : OffloadQueues((Flags & UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE)
                           ? 1
                           : OOO_QUEUE_POOL_SIZE),
-        QueueOffset(0), OffloadDevice(Device), UrContext(UrContext),
-        Flags(Flags) {}
+        QueueOffset(0), Barrier(nullptr), OffloadDevice(Device),
+        UrContext(UrContext), Flags(Flags) {}
 
   // In-order queues only have one element here, while out of order queues have
   // a bank of queues to use. We rotate through them round robin instead of
@@ -35,22 +35,37 @@ struct ur_queue_handle_t_ : RefCounted {
   // `stream_queue_t`. In the future, if we want more performance or it
   // simplifies the implementation of a feature, we can consider using it.
   std::vector<ol_queue_handle_t> OffloadQueues;
+  // Mutex guarding the offset and barrier for out of order queues
+  std::mutex OooMutex;
   size_t QueueOffset;
+  ol_event_handle_t Barrier;
   ol_device_handle_t OffloadDevice;
   ur_context_handle_t UrContext;
   ur_queue_flags_t Flags;
 
-  ol_result_t nextQueue(ol_queue_handle_t &Handle) {
-    auto &Slot = OffloadQueues[QueueOffset++];
-    QueueOffset %= OffloadQueues.size();
+  bool isInOrder() const { return OffloadQueues.size() == 1; }
+
+  ol_result_t nextQueueNoLock(ol_queue_handle_t &Handle) {
+    auto &Slot = OffloadQueues[(QueueOffset++) % OffloadQueues.size()];
     if (!Slot) {
       if (auto Res = olCreateQueue(OffloadDevice, &Slot)) {
         return Res;
       }
+
+      if (auto Event = Barrier) {
+        if (auto Res = olWaitEvents(Slot, &Event, 1)) {
+          return Res;
+        }
+      }
     }
     Handle = Slot;
     return nullptr;
   }
+
+  ol_result_t nextQueue(ol_queue_handle_t &Handle) {
+    std::lock_guard Lock(OooMutex);
+    return nextQueueNoLock(Handle);
+  }
 };
diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
index 02de9df99fddc..498b09d7daf92 100644
--- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
+++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
@@ -170,8 +170,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
   }
   pDdiTable->pfnDeviceGlobalVariableRead = urEnqueueDeviceGlobalVariableRead;
   pDdiTable->pfnDeviceGlobalVariableWrite = urEnqueueDeviceGlobalVariableWrite;
-  pDdiTable->pfnEventsWait = nullptr;
-  pDdiTable->pfnEventsWaitWithBarrier = nullptr;
+  pDdiTable->pfnEventsWait = urEnqueueEventsWait;
+  pDdiTable->pfnEventsWaitWithBarrier = urEnqueueEventsWaitWithBarrier;
   pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
   pDdiTable->pfnMemBufferCopy = nullptr;
   pDdiTable->pfnMemBufferCopyRect = nullptr;