Skip to content

Commit 15a2787

Browse files
author
Hugh Delaney
committed
Get device from queue, not event
If an event was constructed using some interop method, it doesn't have an associated queue. Implicit migration of buffers was working on the assumption that each event has an associated device, which can be obtained from the queue. This patch removes these assumptions.
1 parent 4c69624 commit 15a2787

File tree

4 files changed

+31
-27
lines changed

4 files changed

+31
-27
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
extern size_t imageElementByteSize(hipArray_Format ArrayFormat);
2222

23-
ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
23+
ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
2424
uint32_t NumEventsInWaitList,
2525
const ur_event_handle_t *EventWaitList) {
2626
if (!EventWaitList) {
@@ -29,8 +29,8 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
2929
try {
3030
auto Result = forLatestEvents(
3131
EventWaitList, NumEventsInWaitList,
32-
[Stream](ur_event_handle_t Event) -> ur_result_t {
33-
ScopedContext Active(Event->getDevice());
32+
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
33+
ScopedContext Active(Queue->getDevice());
3434
if (Event->isCompleted() || Event->getStream() == Stream) {
3535
return UR_RESULT_SUCCESS;
3636
} else {
@@ -218,8 +218,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
218218
// last queue to write to the MemBuffer, meaning we must perform the copy
219219
// from a different device
220220
if (hBuffer->LastEventWritingToMemObj &&
221-
hBuffer->LastEventWritingToMemObj->getDevice() != hQueue->getDevice()) {
222-
Device = hBuffer->LastEventWritingToMemObj->getDevice();
221+
hBuffer->LastDeviceWritingToMemObj != hQueue->getDevice()) {
222+
Device = hBuffer->LastDeviceWritingToMemObj;
223223
ScopedContext Active(Device);
224224
HIPStream = hipStream_t{0}; // Default stream for different device
225225
// We may have to wait for an event on another queue if it is the last
@@ -367,7 +367,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
367367
// if it has been written to
368368
if (phEvent && (MemArg.AccessFlags &
369369
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
370-
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
370+
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get(),
371+
hQueue->getDevice());
371372
}
372373
}
373374
// We can release the MemoryMigrationMutexes now
@@ -584,8 +585,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
584585
// last queue to write to the MemBuffer, meaning we must perform the copy
585586
// from a different device
586587
if (hBuffer->LastEventWritingToMemObj &&
587-
hBuffer->LastEventWritingToMemObj->getDevice() != hQueue->getDevice()) {
588-
Device = hBuffer->LastEventWritingToMemObj->getDevice();
588+
hBuffer->LastDeviceWritingToMemObj != hQueue->getDevice()) {
589+
Device = hBuffer->LastDeviceWritingToMemObj;
589590
ScopedContext Active(Device);
590591
HIPStream = hipStream_t{0}; // Default stream for different device
591592
// We may have to wait for an event on another queue if it is the last
@@ -1017,8 +1018,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
10171018
// last queue to write to the MemBuffer, meaning we must perform the copy
10181019
// from a different device
10191020
if (hImage->LastEventWritingToMemObj &&
1020-
hImage->LastEventWritingToMemObj->getDevice() != hQueue->getDevice()) {
1021-
Device = hImage->LastEventWritingToMemObj->getDevice();
1021+
hImage->LastDeviceWritingToMemObj != hQueue->getDevice()) {
1022+
Device = hImage->LastDeviceWritingToMemObj;
10221023
ScopedContext Active(Device);
10231024
HIPStream = hipStream_t{0}; // Default stream for different device
10241025
// We may have to wait for an event on another queue if it is the last

source/adapters/hip/event.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ struct ur_event_handle_t_ {
2828

2929
ur_queue_handle_t getQueue() const noexcept { return Queue; }
3030

31-
ur_device_handle_t getDevice() const noexcept { return Queue->getDevice(); }
32-
3331
hipStream_t getStream() const noexcept { return Stream; }
3432

3533
uint32_t getComputeStreamToken() const noexcept { return StreamToken; }

source/adapters/hip/memory.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,10 @@ inline ur_result_t migrateBufferToDevice(ur_mem_handle_t Mem,
525525
UR_CHECK_ERROR(
526526
hipMemcpyHtoD(Buffer.getPtr(hDevice), Buffer.HostPtr, Buffer.Size));
527527
}
528-
} else if (Mem->LastEventWritingToMemObj->getDevice() != hDevice) {
529-
UR_CHECK_ERROR(
530-
hipMemcpyDtoD(Buffer.getPtr(hDevice),
531-
Buffer.getPtr(Mem->LastEventWritingToMemObj->getDevice()),
532-
Buffer.Size));
528+
} else if (Mem->LastDeviceWritingToMemObj != hDevice) {
529+
UR_CHECK_ERROR(hipMemcpyDtoD(Buffer.getPtr(hDevice),
530+
Buffer.getPtr(Mem->LastDeviceWritingToMemObj),
531+
Buffer.Size));
533532
}
534533
return UR_RESULT_SUCCESS;
535534
}
@@ -577,22 +576,19 @@ inline ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
577576
CpyDesc3D.srcHost = Image.HostPtr;
578577
UR_CHECK_ERROR(hipDrvMemcpy3D(&CpyDesc3D));
579578
}
580-
} else if (Mem->LastEventWritingToMemObj->getDevice() != hDevice) {
579+
} else if (Mem->LastDeviceWritingToMemObj != hDevice) {
581580
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
582581
// FIXME: 1D memcpy from DtoD going through the host.
583582
UR_CHECK_ERROR(hipMemcpyAtoH(
584-
Image.HostPtr,
585-
Image.getArray(Mem->LastEventWritingToMemObj->getDevice()),
583+
Image.HostPtr, Image.getArray(Mem->LastDeviceWritingToMemObj),
586584
0 /*srcOffset*/, ImageSizeBytes));
587585
UR_CHECK_ERROR(
588586
hipMemcpyHtoA(ImageArray, 0, Image.HostPtr, ImageSizeBytes));
589587
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
590-
CpyDesc2D.srcArray =
591-
Image.getArray(Mem->LastEventWritingToMemObj->getDevice());
588+
CpyDesc2D.srcArray = Image.getArray(Mem->LastDeviceWritingToMemObj);
592589
UR_CHECK_ERROR(hipMemcpyParam2D(&CpyDesc2D));
593590
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
594-
CpyDesc3D.srcArray =
595-
Image.getArray(Mem->LastEventWritingToMemObj->getDevice());
591+
CpyDesc3D.srcArray = Image.getArray(Mem->LastDeviceWritingToMemObj);
596592
UR_CHECK_ERROR(hipDrvMemcpy3D(&CpyDesc3D));
597593
}
598594
}

source/adapters/hip/memory.hpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ struct ur_mem_handle_t_ {
393393
// We should wait on this event prior to migrating memory across allocations
394394
// in this ur_mem_handle_t_
395395
ur_event_handle_t LastEventWritingToMemObj{nullptr};
396+
// Since the event may not contain device info (if using interop, which
397+
// doesn't take a queue) we should use this member var to keep track of which
398+
// device has most recent view of data
399+
ur_device_handle_t LastDeviceWritingToMemObj{nullptr};
396400

397401
// Enumerates all possible types of accesses.
398402
enum access_mode_t { unknown, read_write, read_only, write_only };
@@ -487,18 +491,23 @@ struct ur_mem_handle_t_ {
487491

488492
uint32_t getReferenceCount() const noexcept { return RefCount; }
489493

490-
void setLastEventWritingToMemObj(ur_event_handle_t NewEvent) {
494+
void setLastEventWritingToMemObj(ur_event_handle_t NewEvent,
495+
ur_device_handle_t RecentDevice) {
491496
assert(NewEvent && "Invalid event!");
492497
// This entry point should only ever be called when using multi device ctx
493498
assert(Context->Devices.size() > 1);
499+
urEventRetain(NewEvent);
500+
urDeviceRetain(RecentDevice);
494501
if (LastEventWritingToMemObj != nullptr) {
495502
urEventRelease(LastEventWritingToMemObj);
496503
}
497-
urEventRetain(NewEvent);
504+
if (LastDeviceWritingToMemObj != nullptr) {
505+
urDeviceRelease(LastDeviceWritingToMemObj);
506+
}
498507
LastEventWritingToMemObj = NewEvent;
499508
for (const auto &Device : Context->getDevices()) {
500509
HaveMigratedToDeviceSinceLastWrite[Device->getIndex()] =
501-
Device == NewEvent->getDevice();
510+
Device == RecentDevice;
502511
}
503512
}
504513
};

0 commit comments

Comments
 (0)