Skip to content

Commit a2eb1e4

Browse files
author
Hugh Delaney
committed
Remove LastDeviceWritingToMemObj
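LastEventWritingToMemObj is never an interop event, so it always has an associated queue. The device that last wrote to the memory object can therefore be obtained from that event's queue, which makes a separate LastDeviceWritingToMemObj member unnecessary.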
1 parent 15a2787 commit a2eb1e4

File tree

3 files changed

+28
-28
lines changed

3 files changed

+28
-28
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
218218
// last queue to write to the MemBuffer, meaning we must perform the copy
219219
// from a different device
220220
if (hBuffer->LastEventWritingToMemObj &&
221-
hBuffer->LastDeviceWritingToMemObj != hQueue->getDevice()) {
222-
Device = hBuffer->LastDeviceWritingToMemObj;
221+
hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
222+
hQueue->getDevice()) {
223+
Device = hBuffer->LastEventWritingToMemObj->getQueue()->getDevice();
223224
ScopedContext Active(Device);
224225
HIPStream = hipStream_t{0}; // Default stream for different device
225226
// We may have to wait for an event on another queue if it is the last
@@ -367,8 +368,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
367368
// if it has been written to
368369
if (phEvent && (MemArg.AccessFlags &
369370
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
370-
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get(),
371-
hQueue->getDevice());
371+
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
372372
}
373373
}
374374
// We can release the MemoryMigrationMutexes now
@@ -585,8 +585,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
585585
// last queue to write to the MemBuffer, meaning we must perform the copy
586586
// from a different device
587587
if (hBuffer->LastEventWritingToMemObj &&
588-
hBuffer->LastDeviceWritingToMemObj != hQueue->getDevice()) {
589-
Device = hBuffer->LastDeviceWritingToMemObj;
588+
hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
589+
hQueue->getDevice()) {
590+
Device = hBuffer->LastEventWritingToMemObj->getQueue()->getDevice();
590591
ScopedContext Active(Device);
591592
HIPStream = hipStream_t{0}; // Default stream for different device
592593
// We may have to wait for an event on another queue if it is the last
@@ -1018,8 +1019,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
10181019
// last queue to write to the MemBuffer, meaning we must perform the copy
10191020
// from a different device
10201021
if (hImage->LastEventWritingToMemObj &&
1021-
hImage->LastDeviceWritingToMemObj != hQueue->getDevice()) {
1022-
Device = hImage->LastDeviceWritingToMemObj;
1022+
hImage->LastEventWritingToMemObj->getQueue()->getDevice() !=
1023+
hQueue->getDevice()) {
1024+
Device = hImage->LastEventWritingToMemObj->getQueue()->getDevice();
10231025
ScopedContext Active(Device);
10241026
HIPStream = hipStream_t{0}; // Default stream for different device
10251027
// We may have to wait for an event on another queue if it is the last

source/adapters/hip/memory.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -525,10 +525,12 @@ inline ur_result_t migrateBufferToDevice(ur_mem_handle_t Mem,
525525
UR_CHECK_ERROR(
526526
hipMemcpyHtoD(Buffer.getPtr(hDevice), Buffer.HostPtr, Buffer.Size));
527527
}
528-
} else if (Mem->LastDeviceWritingToMemObj != hDevice) {
529-
UR_CHECK_ERROR(hipMemcpyDtoD(Buffer.getPtr(hDevice),
530-
Buffer.getPtr(Mem->LastDeviceWritingToMemObj),
531-
Buffer.Size));
528+
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
529+
hDevice) {
530+
UR_CHECK_ERROR(hipMemcpyDtoD(
531+
Buffer.getPtr(hDevice),
532+
Buffer.getPtr(Mem->LastEventWritingToMemObj->getQueue()->getDevice()),
533+
Buffer.Size));
532534
}
533535
return UR_RESULT_SUCCESS;
534536
}
@@ -576,19 +578,24 @@ inline ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
576578
CpyDesc3D.srcHost = Image.HostPtr;
577579
UR_CHECK_ERROR(hipDrvMemcpy3D(&CpyDesc3D));
578580
}
579-
} else if (Mem->LastDeviceWritingToMemObj != hDevice) {
581+
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
582+
hDevice) {
580583
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
581584
// FIXME: 1D memcpy from DtoD going through the host.
582585
UR_CHECK_ERROR(hipMemcpyAtoH(
583-
Image.HostPtr, Image.getArray(Mem->LastDeviceWritingToMemObj),
586+
Image.HostPtr,
587+
Image.getArray(
588+
Mem->LastEventWritingToMemObj->getQueue()->getDevice()),
584589
0 /*srcOffset*/, ImageSizeBytes));
585590
UR_CHECK_ERROR(
586591
hipMemcpyHtoA(ImageArray, 0, Image.HostPtr, ImageSizeBytes));
587592
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
588-
CpyDesc2D.srcArray = Image.getArray(Mem->LastDeviceWritingToMemObj);
593+
CpyDesc2D.srcArray = Image.getArray(
594+
Mem->LastEventWritingToMemObj->getQueue()->getDevice());
589595
UR_CHECK_ERROR(hipMemcpyParam2D(&CpyDesc2D));
590596
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
591-
CpyDesc3D.srcArray = Image.getArray(Mem->LastDeviceWritingToMemObj);
597+
CpyDesc3D.srcArray = Image.getArray(
598+
Mem->LastEventWritingToMemObj->getQueue()->getDevice());
592599
UR_CHECK_ERROR(hipDrvMemcpy3D(&CpyDesc3D));
593600
}
594601
}

source/adapters/hip/memory.hpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,6 @@ struct ur_mem_handle_t_ {
393393
// We should wait on this event prior to migrating memory across allocations
394394
// in this ur_mem_handle_t_
395395
ur_event_handle_t LastEventWritingToMemObj{nullptr};
396-
// Since the event may not contain device info (if using interop, which
397-
// doesn't take a queue) we should use this member var to keep track of which
398-
// device has most recent view of data
399-
ur_device_handle_t LastDeviceWritingToMemObj{nullptr};
400396

401397
// Enumerates all possible types of accesses.
402398
enum access_mode_t { unknown, read_write, read_only, write_only };
@@ -491,23 +487,18 @@ struct ur_mem_handle_t_ {
491487

492488
uint32_t getReferenceCount() const noexcept { return RefCount; }
493489

494-
void setLastEventWritingToMemObj(ur_event_handle_t NewEvent,
495-
ur_device_handle_t RecentDevice) {
490+
void setLastEventWritingToMemObj(ur_event_handle_t NewEvent) {
496491
assert(NewEvent && "Invalid event!");
497492
// This entry point should only ever be called when using multi device ctx
498493
assert(Context->Devices.size() > 1);
499-
urEventRetain(NewEvent);
500-
urDeviceRetain(RecentDevice);
501494
if (LastEventWritingToMemObj != nullptr) {
502495
urEventRelease(LastEventWritingToMemObj);
503496
}
504-
if (LastDeviceWritingToMemObj != nullptr) {
505-
urDeviceRelease(LastDeviceWritingToMemObj);
506-
}
497+
urEventRetain(NewEvent);
507498
LastEventWritingToMemObj = NewEvent;
508499
for (const auto &Device : Context->getDevices()) {
509500
HaveMigratedToDeviceSinceLastWrite[Device->getIndex()] =
510-
Device == RecentDevice;
501+
Device == NewEvent->getQueue()->getDevice();
511502
}
512503
}
513504
};

0 commit comments

Comments
 (0)