Skip to content

Commit fa31c6e

Browse files
authored
[UR][Offload] Various small fixes for offload adapter (#19832)
Some functions were added, the barrier event now is reference counted (rather than being dropped when the event is destroyed) and empty events no longer cause an error.
1 parent 665516d commit fa31c6e

File tree

7 files changed

+84
-18
lines changed

7 files changed

+84
-18
lines changed

unified-runtime/source/adapters/offload/context.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,28 @@
1717
#include <unordered_map>
1818
#include <ur_api.h>
1919

20+
struct alloc_info_t {
21+
ol_alloc_type_t Type;
22+
size_t Size;
23+
};
24+
2025
struct ur_context_handle_t_ : RefCounted {
2126
ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
2227
urDeviceRetain(Device);
2328
}
2429
~ur_context_handle_t_() { urDeviceRelease(Device); }
2530

2631
ur_device_handle_t Device;
27-
std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap;
32+
std::unordered_map<void *, alloc_info_t> AllocTypeMap;
33+
34+
std::optional<alloc_info_t> getAllocType(const void *UsmPtr) {
35+
for (auto &pair : AllocTypeMap) {
36+
if (UsmPtr >= pair.first &&
37+
reinterpret_cast<uintptr_t>(UsmPtr) <
38+
reinterpret_cast<uintptr_t>(pair.first) + pair.second.Size) {
39+
return pair.second;
40+
}
41+
}
42+
return std::nullopt;
43+
}
2844
};

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,19 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
9393
OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
9494

9595
if constexpr (Barrier) {
96-
ol_event_handle_t BarrierEvent;
96+
ur_event_handle_t BarrierEvent;
9797
if (phEvent) {
98-
BarrierEvent = (*phEvent)->OffloadEvent;
98+
BarrierEvent = *phEvent;
99+
urEventRetain(BarrierEvent);
99100
} else {
100-
OL_RETURN_ON_ERR(olCreateEvent(TargetQueue, &BarrierEvent));
101+
OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, &BarrierEvent));
101102
}
102103

103104
// Ensure any newly created work waits on this barrier
104105
if (hQueue->Barrier) {
105-
OL_RETURN_ON_ERR(olDestroyEvent(hQueue->Barrier));
106+
if (auto Err = urEventRelease(hQueue->Barrier)) {
107+
return Err;
108+
}
106109
}
107110
hQueue->Barrier = BarrierEvent;
108111

@@ -114,7 +117,7 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
114117
if (Q == TargetQueue) {
115118
continue;
116119
}
117-
OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent, 1));
120+
OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent->OffloadEvent, 1));
118121
}
119122
}
120123

@@ -260,6 +263,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
260263
blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
261264
}
262265

266+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
267+
ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
268+
ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
269+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
270+
ur_event_handle_t *phEvent) {
271+
char *DevPtrSrc =
272+
reinterpret_cast<char *>(std::get<BufferMem>(hBufferSrc->Mem).Ptr);
273+
char *DevPtrDst =
274+
reinterpret_cast<char *>(std::get<BufferMem>(hBufferDst->Mem).Ptr);
275+
276+
return doMemcpy(UR_COMMAND_MEM_BUFFER_COPY, hQueue, DevPtrDst + dstOffset,
277+
hQueue->OffloadDevice, DevPtrSrc + srcOffset,
278+
hQueue->OffloadDevice, size, false, numEventsInWaitList,
279+
phEventWaitList, phEvent);
280+
}
281+
263282
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
264283
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
265284
bool blockingRead, size_t count, size_t offset, void *pDst,
@@ -366,3 +385,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
366385

367386
return Result;
368387
}
388+
389+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
390+
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
391+
size_t size, uint32_t numEventsInWaitList,
392+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
393+
auto GetDevice = [&](const void *Ptr) {
394+
auto Res = hQueue->UrContext->getAllocType(Ptr);
395+
if (!Res)
396+
return Adapter->HostDevice;
397+
return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice
398+
: hQueue->OffloadDevice;
399+
};
400+
401+
return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc,
402+
GetDevice(pSrc), size, blocking, numEventsInWaitList,
403+
phEventWaitList, phEvent);
404+
405+
return UR_RESULT_SUCCESS;
406+
}

unified-runtime/source/adapters/offload/event.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
6464

6565
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
6666
if (--hEvent->RefCount == 0) {
67-
auto Res = olDestroyEvent(hEvent->OffloadEvent);
68-
if (Res) {
69-
return offloadResultToUR(Res);
67+
if (hEvent->OffloadEvent) {
68+
auto Res = olDestroyEvent(hEvent->OffloadEvent);
69+
if (Res) {
70+
return offloadResultToUR(Res);
71+
}
7072
}
7173
delete hEvent;
7274
}

unified-runtime/source/adapters/offload/queue.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
105105
const ur_queue_native_properties_t *, ur_queue_handle_t *) {
106106
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
107107
}
108+
109+
UR_APIEXPORT ur_result_t UR_APICALL urQueueFlush(ur_queue_handle_t) {
110+
return UR_RESULT_SUCCESS;
111+
}

unified-runtime/source/adapters/offload/queue.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <ur_api.h>
1515

1616
#include "common.hpp"
17+
#include "event.hpp"
1718

1819
constexpr size_t OOO_QUEUE_POOL_SIZE = 32;
1920

@@ -38,7 +39,7 @@ struct ur_queue_handle_t_ : RefCounted {
3839
// Mutex guarding the offset and barrier for out of order queues
3940
std::mutex OooMutex;
4041
size_t QueueOffset;
41-
ol_event_handle_t Barrier;
42+
ur_event_handle_t Barrier;
4243
ol_device_handle_t OffloadDevice;
4344
ur_context_handle_t UrContext;
4445
ur_queue_flags_t Flags;
@@ -54,7 +55,7 @@ struct ur_queue_handle_t_ : RefCounted {
5455
}
5556

5657
if (auto Event = Barrier) {
57-
if (auto Res = olWaitEvents(Slot, &Event, 1)) {
58+
if (auto Res = olWaitEvents(Slot, &Event->OffloadEvent, 1)) {
5859
return Res;
5960
}
6061
}

unified-runtime/source/adapters/offload/ur_interface_loader.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
173173
pDdiTable->pfnEventsWait = urEnqueueEventsWait;
174174
pDdiTable->pfnEventsWaitWithBarrier = urEnqueueEventsWaitWithBarrier;
175175
pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
176-
pDdiTable->pfnMemBufferCopy = nullptr;
176+
pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy;
177177
pDdiTable->pfnMemBufferCopyRect = nullptr;
178178
pDdiTable->pfnMemBufferFill = nullptr;
179179
pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap;
@@ -189,7 +189,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
189189
pDdiTable->pfnUSMFill = nullptr;
190190
pDdiTable->pfnUSMAdvise = nullptr;
191191
pDdiTable->pfnUSMMemcpy2D = urEnqueueUSMMemcpy2D;
192-
pDdiTable->pfnUSMMemcpy = nullptr;
192+
pDdiTable->pfnUSMMemcpy = urEnqueueUSMMemcpy;
193193
pDdiTable->pfnUSMPrefetch = nullptr;
194194
pDdiTable->pfnReadHostPipe = nullptr;
195195
pDdiTable->pfnWriteHostPipe = nullptr;
@@ -221,7 +221,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueProcAddrTable(
221221
pDdiTable->pfnCreate = urQueueCreate;
222222
pDdiTable->pfnCreateWithNativeHandle = urQueueCreateWithNativeHandle;
223223
pDdiTable->pfnFinish = urQueueFinish;
224-
pDdiTable->pfnFlush = nullptr;
224+
pDdiTable->pfnFlush = urQueueFlush;
225225
pDdiTable->pfnGetInfo = urQueueGetInfo;
226226
pDdiTable->pfnGetNativeHandle = urQueueGetNativeHandle;
227227
pDdiTable->pfnRelease = urQueueRelease;

unified-runtime/source/adapters/offload/usm.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
2323
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
2424
OL_ALLOC_TYPE_HOST, size, ppMem));
2525

26-
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST);
26+
hContext->AllocTypeMap.insert_or_assign(
27+
*ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size});
2728
return UR_RESULT_SUCCESS;
2829
}
2930

@@ -33,7 +34,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
3334
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
3435
OL_ALLOC_TYPE_DEVICE, size, ppMem));
3536

36-
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE);
37+
hContext->AllocTypeMap.insert_or_assign(
38+
*ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size});
3739
return UR_RESULT_SUCCESS;
3840
}
3941

@@ -43,10 +45,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
4345
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
4446
OL_ALLOC_TYPE_MANAGED, size, ppMem));
4547

46-
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_MANAGED);
48+
hContext->AllocTypeMap.insert_or_assign(
49+
*ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size});
4750
return UR_RESULT_SUCCESS;
4851
}
4952

50-
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t, void *pMem) {
53+
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
54+
void *pMem) {
55+
hContext->AllocTypeMap.erase(pMem);
5156
return offloadResultToUR(olMemFree(pMem));
5257
}

0 commit comments

Comments
 (0)