Skip to content

Commit 3eef2fe

Browse files
author
Hugh Delaney
committed
Remove dependency analysis for ur_mem_handle_ts
It is not the responsibility of a UR adapter to build a DAG of commands for each ur_mem_handle_t. This simplifies the code considerably.
1 parent 7345640 commit 3eef2fe

File tree

6 files changed

+102
-425
lines changed

6 files changed

+102
-425
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 31 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -414,36 +414,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
414414
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
415415
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
416416

417-
std::vector<ur_event_handle_t> MemMigrationEvents;
418-
std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
419-
420-
// phEventWaitList only contains events that are handed to UR by the SYCL
421-
// runtime. However since UR handles memory dependencies within a context
422-
// we may need to add more events to our dependent events list if the UR
423-
// context contains multiple devices
424-
if (hQueue->getContext()->Devices.size() > 1) {
425-
MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
426-
for (auto &MemArg : hKernel->Args.MemObjArgs) {
427-
bool PushBack = false;
428-
if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
429-
MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
430-
MemDepEvent)) {
431-
MemMigrationEvents.push_back(MemDepEvent);
432-
PushBack = true;
433-
}
434-
if ((MemArg.AccessFlags &
435-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
436-
PushBack) {
437-
if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
438-
[MemArg](auto &Lock) {
439-
return Lock.first == MemArg.Mem;
440-
}) == MemMigrationLocks.end())
441-
MemMigrationLocks.emplace_back(
442-
std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
443-
}
444-
}
445-
}
446-
447417
// Early exit for zero size kernel
448418
if (*pGlobalWorkSize == 0) {
449419
return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
@@ -481,14 +451,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
481451

482452
// For memory migration across devices in the same context
483453
if (hQueue->getContext()->Devices.size() > 1) {
484-
if (MemMigrationEvents.size()) {
485-
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
486-
MemMigrationEvents.size(),
487-
MemMigrationEvents.data()));
488-
}
489454
for (auto &MemArg : hKernel->Args.MemObjArgs) {
490455
enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
491456
CuStream);
457+
if (MemArg.AccessFlags &
458+
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
459+
MemArg.Mem->setLastQueueWritingToMemObj(hQueue);
460+
}
492461
}
493462
}
494463

@@ -499,20 +468,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
499468
UR_CHECK_ERROR(RetImplEvent->start());
500469
}
501470

502-
// Once event has been started we can unlock MemoryMigrationMutex
503-
if (phEvent && hQueue->getContext()->Devices.size() > 1) {
504-
for (auto &MemArg : hKernel->Args.MemObjArgs) {
505-
// Telling the ur_mem_handle_t that it will need to wait on this kernel
506-
// if it has been written to
507-
if (MemArg.AccessFlags &
508-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
509-
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
510-
}
511-
}
512-
// We can release the MemoryMigrationMutexes now
513-
MemMigrationLocks.clear();
514-
}
515-
516471
auto &ArgIndices = hKernel->getArgIndices();
517472
UR_CHECK_ERROR(cuLaunchKernel(
518473
CuFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
@@ -605,36 +560,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
605560
}
606561
}
607562

608-
std::vector<ur_event_handle_t> MemMigrationEvents;
609-
std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
610-
611-
// phEventWaitList only contains events that are handed to UR by the SYCL
612-
// runtime. However since UR handles memory dependencies within a context
613-
// we may need to add more events to our dependent events list if the UR
614-
// context contains multiple devices
615-
if (hQueue->getContext()->Devices.size() > 1) {
616-
MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
617-
for (auto &MemArg : hKernel->Args.MemObjArgs) {
618-
bool PushBack = false;
619-
if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
620-
MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
621-
MemDepEvent)) {
622-
MemMigrationEvents.push_back(MemDepEvent);
623-
PushBack = true;
624-
}
625-
if ((MemArg.AccessFlags &
626-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
627-
PushBack) {
628-
if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
629-
[MemArg](auto &Lock) {
630-
return Lock.first == MemArg.Mem;
631-
}) == MemMigrationLocks.end())
632-
MemMigrationLocks.emplace_back(
633-
std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
634-
}
635-
}
636-
}
637-
638563
// Early exit for zero size kernel
639564
if (*pGlobalWorkSize == 0) {
640565
return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
@@ -672,14 +597,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
672597

673598
// For memory migration across devices in the same context
674599
if (hQueue->getContext()->Devices.size() > 1) {
675-
if (MemMigrationEvents.size()) {
676-
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
677-
MemMigrationEvents.size(),
678-
MemMigrationEvents.data()));
679-
}
680600
for (auto &MemArg : hKernel->Args.MemObjArgs) {
681601
enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
682602
CuStream);
603+
if (MemArg.AccessFlags &
604+
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
605+
MemArg.Mem->setLastQueueWritingToMemObj(hQueue);
606+
}
683607
}
684608
}
685609

@@ -690,20 +614,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
690614
UR_CHECK_ERROR(RetImplEvent->start());
691615
}
692616

693-
// Once event has been started we can unlock MemoryMigrationMutex
694-
if (phEvent && hQueue->getContext()->Devices.size() > 1) {
695-
for (auto &MemArg : hKernel->Args.MemObjArgs) {
696-
// Telling the ur_mem_handle_t that it will need to wait on this kernel
697-
// if it has been written to
698-
if (MemArg.AccessFlags &
699-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
700-
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
701-
}
702-
}
703-
// We can release the MemoryMigrationMutexes now
704-
MemMigrationLocks.clear();
705-
}
706-
707617
auto &ArgIndices = hKernel->getArgIndices();
708618

709619
CUlaunchConfig launch_config;
@@ -824,28 +734,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
824734
ur_event_handle_t *phEvent) {
825735
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
826736

827-
ur_lock MemoryMigrationLock{hBuffer->MemoryMigrationMutex};
828-
auto Device = hQueue->getDevice();
829-
ScopedContext Active(Device);
830-
CUstream Stream = hQueue->getNextTransferStream();
831-
832737
try {
833738
// Note that this entry point may be called on a queue that may not be the
834739
// last queue to write to the MemBuffer, meaning we must perform the copy
835740
// from a different device
836-
if (hBuffer->LastEventWritingToMemObj &&
837-
hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
838-
hQueue->getDevice()) {
839-
hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
840-
Device = hQueue->getDevice();
841-
ScopedContext Active(Device);
842-
Stream = CUstream{0}; // Default stream for different device
843-
// We may have to wait for an event on another queue if it is the last
844-
// event writing to mem obj
845-
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
846-
&hBuffer->LastEventWritingToMemObj));
741+
if (hBuffer->LastQueueWritingToMemObj &&
742+
hBuffer->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
743+
hQueue = hBuffer->LastQueueWritingToMemObj;
847744
}
848745

746+
auto Device = hQueue->getDevice();
747+
ScopedContext Active(Device);
748+
CUstream Stream = hQueue->getNextTransferStream();
749+
849750
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
850751
phEventWaitList));
851752

@@ -890,7 +791,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
890791
CUdeviceptr DevPtr =
891792
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
892793
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
893-
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
794+
hBuffer->setLastQueueWritingToMemObj(hQueue);
894795

895796
try {
896797
ScopedContext Active(hQueue->getDevice());
@@ -920,7 +821,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
920821

921822
if (phEvent) {
922823
*phEvent = RetImplEvent.release();
923-
hBuffer->setLastEventWritingToMemObj(*phEvent);
924824
}
925825
} catch (ur_result_t Err) {
926826
return Err;
@@ -1060,7 +960,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
1060960
UR_ASSERT(size + offset <= std::get<BufferMem>(hBuffer->Mem).getSize(),
1061961
UR_RESULT_ERROR_INVALID_SIZE);
1062962
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
1063-
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
963+
hBuffer->setLastQueueWritingToMemObj(hQueue);
1064964

1065965
try {
1066966
ScopedContext Active(hQueue->getDevice());
@@ -1107,7 +1007,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
11071007
if (phEvent) {
11081008
UR_CHECK_ERROR(RetImplEvent->record());
11091009
*phEvent = RetImplEvent.release();
1110-
hBuffer->setLastEventWritingToMemObj(*phEvent);
11111010
}
11121011
} catch (ur_result_t Err) {
11131012
return Err;
@@ -1215,28 +1114,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
12151114

12161115
UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
12171116

1218-
ur_lock MemoryMigrationLock{hImage->MemoryMigrationMutex};
1219-
auto Device = hQueue->getDevice();
1220-
CUstream Stream = hQueue->getNextTransferStream();
1221-
12221117
try {
12231118
// Note that this entry point may be called on a queue that may not be the
12241119
// last queue to write to the Image, meaning we must perform the copy
12251120
// from a different device
1226-
if (hImage->LastEventWritingToMemObj &&
1227-
hImage->LastEventWritingToMemObj->getQueue()->getDevice() !=
1228-
hQueue->getDevice()) {
1229-
hQueue = hImage->LastEventWritingToMemObj->getQueue();
1230-
Device = hQueue->getDevice();
1231-
ScopedContext Active(Device);
1232-
Stream = CUstream{0}; // Default stream for different device
1233-
// We may have to wait for an event on another queue if it is the last
1234-
// event writing to mem obj
1235-
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
1236-
&hImage->LastEventWritingToMemObj));
1121+
if (hImage->LastQueueWritingToMemObj &&
1122+
hImage->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
1123+
hQueue = hImage->LastQueueWritingToMemObj;
12371124
}
12381125

1126+
auto Device = hQueue->getDevice();
12391127
ScopedContext Active(Device);
1128+
CUstream Stream = hQueue->getNextTransferStream();
1129+
12401130
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
12411131
phEventWaitList));
12421132

@@ -1839,28 +1729,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
18391729
UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).Size,
18401730
UR_RESULT_ERROR_INVALID_SIZE);
18411731
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
1842-
ur_lock MemoryMigrationLock{hBuffer->MemoryMigrationMutex};
1843-
auto Device = hQueue->getDevice();
1844-
ScopedContext Active(Device);
1845-
CUstream Stream = hQueue->getNextTransferStream();
18461732

18471733
try {
18481734
// Note that this entry point may be called on a queue that may not be the
18491735
// last queue to write to the MemBuffer, meaning we must perform the copy
18501736
// from a different device
1851-
if (hBuffer->LastEventWritingToMemObj &&
1852-
hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
1853-
hQueue->getDevice()) {
1854-
hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
1855-
Device = hQueue->getDevice();
1856-
ScopedContext Active(Device);
1857-
Stream = CUstream{0}; // Default stream for different device
1858-
// We may have to wait for an event on another queue if it is the last
1859-
// event writing to mem obj
1860-
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
1861-
&hBuffer->LastEventWritingToMemObj));
1737+
if (hBuffer->LastQueueWritingToMemObj &&
1738+
hBuffer->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
1739+
hQueue = hBuffer->LastQueueWritingToMemObj;
18621740
}
18631741

1742+
auto Device = hQueue->getDevice();
1743+
ScopedContext Active(Device);
1744+
CUstream Stream = hQueue->getNextTransferStream();
1745+
18641746
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
18651747
phEventWaitList));
18661748

@@ -1905,7 +1787,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19051787
CUdeviceptr DevPtr =
19061788
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
19071789
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
1908-
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
1790+
hBuffer->setLastQueueWritingToMemObj(hQueue);
19091791

19101792
try {
19111793
ScopedContext Active(hQueue->getDevice());
@@ -1933,7 +1815,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19331815

19341816
if (phEvent) {
19351817
*phEvent = RetImplEvent.release();
1936-
hBuffer->setLastEventWritingToMemObj(*phEvent);
19371818
}
19381819
} catch (ur_result_t Err) {
19391820
return Err;

source/adapters/cuda/memory.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -506,18 +506,17 @@ ur_result_t enqueueMigrateBufferToDevice(ur_mem_handle_t Mem,
506506
ur_device_handle_t hDevice,
507507
CUstream Stream) {
508508
auto &Buffer = std::get<BufferMem>(Mem->Mem);
509-
if (Mem->LastEventWritingToMemObj == nullptr) {
509+
if (Mem->LastQueueWritingToMemObj == nullptr) {
510510
// Device allocation being initialized from host for the first time
511511
if (Buffer.HostPtr) {
512512
UR_CHECK_ERROR(cuMemcpyHtoDAsync(Buffer.getPtr(hDevice), Buffer.HostPtr,
513513
Buffer.Size, Stream));
514514
}
515-
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
516-
hDevice) {
515+
} else if (Mem->LastQueueWritingToMemObj->getDevice() != hDevice) {
517516
UR_CHECK_ERROR(cuMemcpyDtoDAsync(
518517
Buffer.getPtr(hDevice),
519-
Buffer.getPtr(Mem->LastEventWritingToMemObj->getQueue()->getDevice()),
520-
Buffer.Size, Stream));
518+
Buffer.getPtr(Mem->LastQueueWritingToMemObj->getDevice()), Buffer.Size,
519+
Stream));
521520
}
522521
return UR_RESULT_SUCCESS;
523522
}
@@ -555,7 +554,7 @@ ur_result_t enqueueMigrateImageToDevice(ur_mem_handle_t Mem,
555554
CpyDesc3D.Depth = Image.ImageDesc.depth;
556555
}
557556

558-
if (Mem->LastEventWritingToMemObj == nullptr) {
557+
if (Mem->LastQueueWritingToMemObj == nullptr) {
559558
if (Image.HostPtr) {
560559
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
561560
UR_CHECK_ERROR(cuMemcpyHtoAAsync(ImageArray, 0, Image.HostPtr,
@@ -570,29 +569,26 @@ ur_result_t enqueueMigrateImageToDevice(ur_mem_handle_t Mem,
570569
UR_CHECK_ERROR(cuMemcpy3DAsync(&CpyDesc3D, Stream));
571570
}
572571
}
573-
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
574-
hDevice) {
572+
} else if (Mem->LastQueueWritingToMemObj->getDevice() != hDevice) {
575573
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
576-
// Blocking wait needed since we need to sync LastEventWritingToMemObj's
577-
// queue, as well as the current queue with LastEventWritingToMemObj
578-
UR_CHECK_ERROR(urEventWait(1, &Mem->LastEventWritingToMemObj));
574+
// Blocking wait needed
575+
UR_CHECK_ERROR(urQueueFinish(Mem->LastQueueWritingToMemObj));
579576
// FIXME: 1D memcpy from DtoD going through the host.
580577
UR_CHECK_ERROR(cuMemcpyAtoH(
581578
Image.HostPtr,
582-
Image.getArray(
583-
Mem->LastEventWritingToMemObj->getQueue()->getDevice()),
579+
Image.getArray(Mem->LastQueueWritingToMemObj->getDevice()),
584580
0 /*srcOffset*/, ImageSizeBytes));
585581
UR_CHECK_ERROR(
586582
cuMemcpyHtoA(ImageArray, 0, Image.HostPtr, ImageSizeBytes));
587583
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
588584
CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_DEVICE;
589-
CpyDesc2D.srcArray = Image.getArray(
590-
Mem->LastEventWritingToMemObj->getQueue()->getDevice());
585+
CpyDesc2D.srcArray =
586+
Image.getArray(Mem->LastQueueWritingToMemObj->getDevice());
591587
UR_CHECK_ERROR(cuMemcpy2DAsync(&CpyDesc2D, Stream));
592588
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
593589
CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_DEVICE;
594-
CpyDesc3D.srcArray = Image.getArray(
595-
Mem->LastEventWritingToMemObj->getQueue()->getDevice());
590+
CpyDesc3D.srcArray =
591+
Image.getArray(Mem->LastQueueWritingToMemObj->getDevice());
596592
UR_CHECK_ERROR(cuMemcpy3DAsync(&CpyDesc3D, Stream));
597593
}
598594
}

0 commit comments

Comments
 (0)