Offload adapter: track USM allocation sizes, keep the out-of-order barrier as
a UR event, and add urEnqueueMemBufferCopy, urEnqueueUSMMemcpy and
urQueueFlush.

Review note: this patch was recovered from a whitespace-collapsed copy in
which everything between angle brackets had been stripped. Template arguments
were restored from the surrounding usage; the names of the system headers in
the unchanged "#include" context lines could not be recovered and are left
as found. Indentation and blank context lines were reconstructed from the
hunk line counts, so re-generate the patch from a real tree before applying.

diff --git a/unified-runtime/source/adapters/offload/context.hpp b/unified-runtime/source/adapters/offload/context.hpp
index 38857446c47f8..b40d17ad3ae9c 100644
--- a/unified-runtime/source/adapters/offload/context.hpp
+++ b/unified-runtime/source/adapters/offload/context.hpp
@@ -17,6 +17,16 @@
 #include
 #include
+#include <cstdint>
+#include <optional>
 
+// Bookkeeping record kept for each tracked USM allocation.
+struct alloc_info_t {
+  // Offload allocation kind the memory was created with.
+  ol_alloc_type_t Type;
+  // Size in bytes that was requested for the allocation.
+  size_t Size;
+};
+
 struct ur_context_handle_t_ : RefCounted {
   ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
     urDeviceRetain(Device);
@@ -24,5 +34,23 @@ struct ur_context_handle_t_ : RefCounted {
   ~ur_context_handle_t_() { urDeviceRelease(Device); }
 
   ur_device_handle_t Device;
-  std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap;
+  // Tracked USM allocations, keyed by the pointer returned at allocation.
+  std::unordered_map<void *, alloc_info_t> AllocTypeMap;
+
+  // Returns the record of the tracked allocation whose byte range
+  // [base, base + Size) contains UsmPtr, or std::nullopt when the pointer
+  // does not fall inside any tracked allocation. Pointers into the middle of
+  // an allocation are resolved too, which is why this is a linear scan.
+  std::optional<alloc_info_t> getAllocType(const void *UsmPtr) const {
+    const auto Addr = reinterpret_cast<uintptr_t>(UsmPtr);
+    for (const auto &[Base, Info] : AllocTypeMap) {
+      const auto Begin = reinterpret_cast<uintptr_t>(Base);
+      // Compare offsets rather than computing Begin + Size, which could wrap
+      // around for an allocation near the top of the address space.
+      if (Addr >= Begin && Addr - Begin < Info.Size) {
+        return Info;
+      }
+    }
+    return std::nullopt;
+  }
 };
diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp
index 6b9a013538bb4..87bd45eb3f817 100644
--- a/unified-runtime/source/adapters/offload/enqueue.cpp
+++ b/unified-runtime/source/adapters/offload/enqueue.cpp
@@ -93,16 +93,19 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
   OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
 
   if constexpr (Barrier) {
-    ol_event_handle_t BarrierEvent;
+    ur_event_handle_t BarrierEvent;
     if (phEvent) {
-      BarrierEvent = (*phEvent)->OffloadEvent;
+      BarrierEvent = *phEvent;
+      urEventRetain(BarrierEvent);
     } else {
-      OL_RETURN_ON_ERR(olCreateEvent(TargetQueue, &BarrierEvent));
+      OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, &BarrierEvent));
     }
 
     // Ensure any newly created work waits on this barrier
     if (hQueue->Barrier) {
-      OL_RETURN_ON_ERR(olDestroyEvent(hQueue->Barrier));
+      if (auto Err = urEventRelease(hQueue->Barrier)) {
+        return Err;
+      }
     }
     hQueue->Barrier = BarrierEvent;
 
@@ -114,7 +117,7 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
       if (Q == TargetQueue) {
         continue;
       }
-      OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent, 1));
+      OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent->OffloadEvent, 1));
     }
   }
 
@@ -260,6 +263,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
                   blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
 }
 
+// Enqueues a copy of `size` bytes from hBufferSrc (starting at srcOffset) to
+// hBufferDst (starting at dstOffset). Both ends of the copy are addressed
+// through hQueue->OffloadDevice and the copy is submitted with the blocking
+// flag set to false.
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
+    ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
+    ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
+    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
+    ur_event_handle_t *phEvent) {
+  // NOTE(review): the std::get alternative type was lost when this patch was
+  // extracted; BufferMem is assumed - confirm against the mem handle type.
+  char *DevPtrSrc =
+      reinterpret_cast<char *>(std::get<BufferMem>(hBufferSrc->Mem).Ptr);
+  char *DevPtrDst =
+      reinterpret_cast<char *>(std::get<BufferMem>(hBufferDst->Mem).Ptr);
+
+  return doMemcpy(UR_COMMAND_MEM_BUFFER_COPY, hQueue, DevPtrDst + dstOffset,
+                  hQueue->OffloadDevice, DevPtrSrc + srcOffset,
+                  hQueue->OffloadDevice, size, false, numEventsInWaitList,
+                  phEventWaitList, phEvent);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
     ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
     bool blockingRead, size_t count, size_t offset, void *pDst,
@@ -366,3 +391,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
 
   return Result;
 }
+
+// Enqueues a copy of `size` bytes from pSrc to pDst. Each pointer is looked up
+// in the queue's context to decide which device it is copied through.
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
+    ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
+    size_t size, uint32_t numEventsInWaitList,
+    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+  // Pointers the context does not track, and tracked host allocations, map to
+  // the adapter's host device; every other tracked allocation maps to the
+  // queue's offload device.
+  auto GetDevice = [&](const void *Ptr) {
+    auto Res = hQueue->UrContext->getAllocType(Ptr);
+    if (!Res)
+      return Adapter->HostDevice;
+    return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice
+                                           : hQueue->OffloadDevice;
+  };
+
+  return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc,
+                  GetDevice(pSrc), size, blocking, numEventsInWaitList,
+                  phEventWaitList, phEvent);
+}
diff --git a/unified-runtime/source/adapters/offload/event.cpp b/unified-runtime/source/adapters/offload/event.cpp
index aab41ed3d2d0e..ee326df79dd6f 100644
--- a/unified-runtime/source/adapters/offload/event.cpp
+++ b/unified-runtime/source/adapters/offload/event.cpp
@@ -64,9 +64,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
 
 UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
   if (--hEvent->RefCount == 0) {
-    auto Res = olDestroyEvent(hEvent->OffloadEvent);
-    if (Res) {
-      return offloadResultToUR(Res);
+    if (hEvent->OffloadEvent) {
+      auto Res = olDestroyEvent(hEvent->OffloadEvent);
+      if (Res) {
+        return offloadResultToUR(Res);
+      }
     }
     delete hEvent;
   }
diff --git a/unified-runtime/source/adapters/offload/queue.cpp b/unified-runtime/source/adapters/offload/queue.cpp
index 43647d0041496..26a5d34e2ed0c 100644
--- a/unified-runtime/source/adapters/offload/queue.cpp
+++ b/unified-runtime/source/adapters/offload/queue.cpp
@@ -105,3 +105,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
     const ur_queue_native_properties_t *, ur_queue_handle_t *) {
   return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
 }
+
+// Flush is implemented as a no-op that always reports success.
+UR_APIEXPORT ur_result_t UR_APICALL urQueueFlush(ur_queue_handle_t) {
+  return UR_RESULT_SUCCESS;
+}
diff --git a/unified-runtime/source/adapters/offload/queue.hpp b/unified-runtime/source/adapters/offload/queue.hpp
index 8f887a9c3be01..a7106b2411939 100644
--- a/unified-runtime/source/adapters/offload/queue.hpp
+++ b/unified-runtime/source/adapters/offload/queue.hpp
@@ -14,6 +14,7 @@
 #include
 
 #include "common.hpp"
+#include "event.hpp"
 
 constexpr size_t OOO_QUEUE_POOL_SIZE = 32;
 
@@ -38,7 +39,7 @@ struct ur_queue_handle_t_ : RefCounted {
   // Mutex guarding the offset and barrier for out of order queues
   std::mutex OooMutex;
   size_t QueueOffset;
-  ol_event_handle_t Barrier;
+  ur_event_handle_t Barrier;
   ol_device_handle_t OffloadDevice;
   ur_context_handle_t UrContext;
   ur_queue_flags_t Flags;
@@ -54,7 +55,7 @@ struct ur_queue_handle_t_ : RefCounted {
     }
 
     if (auto Event = Barrier) {
-      if (auto Res = olWaitEvents(Slot, &Event, 1)) {
+      if (auto Res = olWaitEvents(Slot, &Event->OffloadEvent, 1)) {
         return Res;
       }
     }
diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
index 5b4c8bd13bc50..145f78d0f90b9 100644
--- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
+++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
@@ -173,7 +173,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
   pDdiTable->pfnEventsWait = urEnqueueEventsWait;
   pDdiTable->pfnEventsWaitWithBarrier = urEnqueueEventsWaitWithBarrier;
   pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
-  pDdiTable->pfnMemBufferCopy = nullptr;
+  pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy;
   pDdiTable->pfnMemBufferCopyRect = nullptr;
   pDdiTable->pfnMemBufferFill = nullptr;
   pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap;
@@ -189,7 +189,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
   pDdiTable->pfnUSMFill = nullptr;
   pDdiTable->pfnUSMAdvise = nullptr;
   pDdiTable->pfnUSMMemcpy2D = urEnqueueUSMMemcpy2D;
-  pDdiTable->pfnUSMMemcpy = nullptr;
+  pDdiTable->pfnUSMMemcpy = urEnqueueUSMMemcpy;
   pDdiTable->pfnUSMPrefetch = nullptr;
   pDdiTable->pfnReadHostPipe = nullptr;
   pDdiTable->pfnWriteHostPipe = nullptr;
@@ -221,7 +221,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueProcAddrTable(
   pDdiTable->pfnCreate = urQueueCreate;
   pDdiTable->pfnCreateWithNativeHandle = urQueueCreateWithNativeHandle;
   pDdiTable->pfnFinish = urQueueFinish;
-  pDdiTable->pfnFlush = nullptr;
+  pDdiTable->pfnFlush = urQueueFlush;
   pDdiTable->pfnGetInfo = urQueueGetInfo;
   pDdiTable->pfnGetNativeHandle = urQueueGetNativeHandle;
   pDdiTable->pfnRelease = urQueueRelease;
diff --git a/unified-runtime/source/adapters/offload/usm.cpp b/unified-runtime/source/adapters/offload/usm.cpp
index 99f7931e9ddd7..f427689618d69 100644
--- a/unified-runtime/source/adapters/offload/usm.cpp
+++ b/unified-runtime/source/adapters/offload/usm.cpp
@@ -23,7 +23,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
   OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
                               OL_ALLOC_TYPE_HOST, size, ppMem));
 
-  hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST);
+  hContext->AllocTypeMap.insert_or_assign(
+      *ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size});
   return UR_RESULT_SUCCESS;
 }
 
@@ -33,7 +34,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
   OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
                               OL_ALLOC_TYPE_DEVICE, size, ppMem));
 
-  hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE);
+  hContext->AllocTypeMap.insert_or_assign(
+      *ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size});
   return UR_RESULT_SUCCESS;
 }
 
@@ -43,10 +45,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
   OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
                               OL_ALLOC_TYPE_MANAGED, size, ppMem));
 
-  hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_MANAGED);
+  hContext->AllocTypeMap.insert_or_assign(
+      *ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size});
   return UR_RESULT_SUCCESS;
 }
 
-UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t, void *pMem) {
+UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
+                                              void *pMem) {
+  // Stop tracking the allocation, then hand it back to the offload runtime.
+  hContext->AllocTypeMap.erase(pMem);
   return offloadResultToUR(olMemFree(pMem));
 }