Skip to content

Commit bf45c90

Browse files
committed
[UR][Offload] Various small fixes for offload adapter
Some functions were added, the barrier event now is reference counted (rather than being dropped when the event is destroyed) and empty events no longer cause an error.
1 parent ea192fd commit bf45c90

File tree

7 files changed

+84
-18
lines changed

7 files changed

+84
-18
lines changed

unified-runtime/source/adapters/offload/context.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,28 @@
1717
#include <unordered_map>
1818
#include <ur_api.h>
1919

20+
struct alloc_info_t {
21+
ol_alloc_type_t Type;
22+
size_t Size;
23+
};
24+
2025
struct ur_context_handle_t_ : RefCounted {
2126
ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
2227
urDeviceRetain(Device);
2328
}
2429
~ur_context_handle_t_() { urDeviceRelease(Device); }
2530

2631
ur_device_handle_t Device;
27-
std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap;
32+
std::unordered_map<void *, alloc_info_t> AllocTypeMap;
33+
34+
std::optional<alloc_info_t> getAllocType(const void *UsmPtr) {
35+
for (auto &pair : AllocTypeMap) {
36+
if (UsmPtr >= pair.first &&
37+
reinterpret_cast<uintptr_t>(UsmPtr) <
38+
reinterpret_cast<uintptr_t>(pair.first) + pair.second.Size) {
39+
return pair.second;
40+
}
41+
}
42+
return std::nullopt;
43+
}
2844
};

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,19 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
9393
OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent));
9494

9595
if constexpr (Barrier) {
96-
ol_event_handle_t BarrierEvent;
96+
ur_event_handle_t BarrierEvent;
9797
if (phEvent) {
98-
BarrierEvent = (*phEvent)->OffloadEvent;
98+
BarrierEvent = *phEvent;
99+
urEventRetain(BarrierEvent);
99100
} else {
100-
OL_RETURN_ON_ERR(olCreateEvent(TargetQueue, &BarrierEvent));
101+
OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, &BarrierEvent));
101102
}
102103

103104
// Ensure any newly created work waits on this barrier
104105
if (hQueue->Barrier) {
105-
OL_RETURN_ON_ERR(olDestroyEvent(hQueue->Barrier));
106+
if (auto Err = urEventRelease(hQueue->Barrier)) {
107+
return Err;
108+
}
106109
}
107110
hQueue->Barrier = BarrierEvent;
108111

@@ -114,7 +117,7 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
114117
if (Q == TargetQueue) {
115118
continue;
116119
}
117-
OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent, 1));
120+
OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent->OffloadEvent, 1));
118121
}
119122
}
120123

@@ -260,6 +263,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
260263
blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
261264
}
262265

266+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
267+
ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
268+
ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
269+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
270+
ur_event_handle_t *phEvent) {
271+
char *DevPtrSrc =
272+
reinterpret_cast<char *>(std::get<BufferMem>(hBufferSrc->Mem).Ptr);
273+
char *DevPtrDst =
274+
reinterpret_cast<char *>(std::get<BufferMem>(hBufferDst->Mem).Ptr);
275+
276+
return doMemcpy(UR_COMMAND_MEM_BUFFER_COPY, hQueue, DevPtrDst + dstOffset,
277+
hQueue->OffloadDevice, DevPtrSrc + srcOffset,
278+
hQueue->OffloadDevice, size, false, numEventsInWaitList,
279+
phEventWaitList, phEvent);
280+
}
281+
263282
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
264283
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
265284
bool blockingRead, size_t count, size_t offset, void *pDst,
@@ -366,3 +385,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
366385

367386
return Result;
368387
}
388+
389+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
390+
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
391+
size_t size, uint32_t numEventsInWaitList,
392+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
393+
auto GetDevice = [&](const void *Ptr) {
394+
auto Res = hQueue->UrContext->getAllocType(Ptr);
395+
if (!Res)
396+
return Adapter->HostDevice;
397+
return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice
398+
: hQueue->OffloadDevice;
399+
};
400+
401+
return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc,
402+
GetDevice(pSrc), size, blocking, numEventsInWaitList,
403+
phEventWaitList, phEvent);
404+
405+
return UR_RESULT_SUCCESS;
406+
}

unified-runtime/source/adapters/offload/event.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
6464

6565
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
6666
if (--hEvent->RefCount == 0) {
67-
auto Res = olDestroyEvent(hEvent->OffloadEvent);
68-
if (Res) {
69-
return offloadResultToUR(Res);
67+
if (hEvent->OffloadEvent) {
68+
auto Res = olDestroyEvent(hEvent->OffloadEvent);
69+
if (Res) {
70+
return offloadResultToUR(Res);
71+
}
7072
}
7173
delete hEvent;
7274
}

unified-runtime/source/adapters/offload/queue.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
105105
const ur_queue_native_properties_t *, ur_queue_handle_t *) {
106106
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
107107
}
108+
109+
UR_APIEXPORT ur_result_t UR_APICALL urQueueFlush(ur_queue_handle_t) {
110+
return UR_RESULT_SUCCESS;
111+
}

unified-runtime/source/adapters/offload/queue.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <ur_api.h>
1515

1616
#include "common.hpp"
17+
#include "event.hpp"
1718

1819
constexpr size_t OOO_QUEUE_POOL_SIZE = 32;
1920

@@ -38,7 +39,7 @@ struct ur_queue_handle_t_ : RefCounted {
3839
// Mutex guarding the offset and barrier for out of order queues
3940
std::mutex OooMutex;
4041
size_t QueueOffset;
41-
ol_event_handle_t Barrier;
42+
ur_event_handle_t Barrier;
4243
ol_device_handle_t OffloadDevice;
4344
ur_context_handle_t UrContext;
4445
ur_queue_flags_t Flags;
@@ -54,7 +55,7 @@ struct ur_queue_handle_t_ : RefCounted {
5455
}
5556

5657
if (auto Event = Barrier) {
57-
if (auto Res = olWaitEvents(Slot, &Event, 1)) {
58+
if (auto Res = olWaitEvents(Slot, &Event->OffloadEvent, 1)) {
5859
return Res;
5960
}
6061
}

unified-runtime/source/adapters/offload/ur_interface_loader.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
173173
pDdiTable->pfnEventsWait = urEnqueueEventsWait;
174174
pDdiTable->pfnEventsWaitWithBarrier = urEnqueueEventsWaitWithBarrier;
175175
pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
176-
pDdiTable->pfnMemBufferCopy = nullptr;
176+
pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy;
177177
pDdiTable->pfnMemBufferCopyRect = nullptr;
178178
pDdiTable->pfnMemBufferFill = nullptr;
179179
pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap;
@@ -189,7 +189,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
189189
pDdiTable->pfnUSMFill = nullptr;
190190
pDdiTable->pfnUSMAdvise = nullptr;
191191
pDdiTable->pfnUSMMemcpy2D = urEnqueueUSMMemcpy2D;
192-
pDdiTable->pfnUSMMemcpy = nullptr;
192+
pDdiTable->pfnUSMMemcpy = urEnqueueUSMMemcpy;
193193
pDdiTable->pfnUSMPrefetch = nullptr;
194194
pDdiTable->pfnReadHostPipe = nullptr;
195195
pDdiTable->pfnWriteHostPipe = nullptr;
@@ -221,7 +221,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueProcAddrTable(
221221
pDdiTable->pfnCreate = urQueueCreate;
222222
pDdiTable->pfnCreateWithNativeHandle = urQueueCreateWithNativeHandle;
223223
pDdiTable->pfnFinish = urQueueFinish;
224-
pDdiTable->pfnFlush = nullptr;
224+
pDdiTable->pfnFlush = urQueueFlush;
225225
pDdiTable->pfnGetInfo = urQueueGetInfo;
226226
pDdiTable->pfnGetNativeHandle = urQueueGetNativeHandle;
227227
pDdiTable->pfnRelease = urQueueRelease;

unified-runtime/source/adapters/offload/usm.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
2323
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
2424
OL_ALLOC_TYPE_HOST, size, ppMem));
2525

26-
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST);
26+
hContext->AllocTypeMap.insert_or_assign(
27+
*ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size});
2728
return UR_RESULT_SUCCESS;
2829
}
2930

@@ -33,7 +34,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
3334
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
3435
OL_ALLOC_TYPE_DEVICE, size, ppMem));
3536

36-
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE);
37+
hContext->AllocTypeMap.insert_or_assign(
38+
*ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size});
3739
return UR_RESULT_SUCCESS;
3840
}
3941

@@ -43,10 +45,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
4345
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
4446
OL_ALLOC_TYPE_MANAGED, size, ppMem));
4547

46-
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_MANAGED);
48+
hContext->AllocTypeMap.insert_or_assign(
49+
*ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size});
4750
return UR_RESULT_SUCCESS;
4851
}
4952

50-
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t, void *pMem) {
53+
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
54+
void *pMem) {
55+
hContext->AllocTypeMap.erase(pMem);
5156
return offloadResultToUR(olMemFree(pMem));
5257
}

0 commit comments

Comments
 (0)