@@ -414,36 +414,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
   UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
 
-  std::vector<ur_event_handle_t> MemMigrationEvents;
-  std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
-
-  // phEventWaitList only contains events that are handed to UR by the SYCL
-  // runtime. However since UR handles memory dependencies within a context
-  // we may need to add more events to our dependent events list if the UR
-  // context contains multiple devices
-  if (hQueue->getContext()->Devices.size() > 1) {
-    MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
-    for (auto &MemArg : hKernel->Args.MemObjArgs) {
-      bool PushBack = false;
-      if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
-          MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
-                                           MemDepEvent)) {
-        MemMigrationEvents.push_back(MemDepEvent);
-        PushBack = true;
-      }
-      if ((MemArg.AccessFlags &
-           (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
-          PushBack) {
-        if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
-                         [MemArg](auto &Lock) {
-                           return Lock.first == MemArg.Mem;
-                         }) == MemMigrationLocks.end())
-          MemMigrationLocks.emplace_back(
-              std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
-      }
-    }
-  }
-
   // Early exit for zero size kernel
   if (*pGlobalWorkSize == 0) {
     return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
@@ -481,14 +451,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 
   // For memory migration across devices in the same context
   if (hQueue->getContext()->Devices.size() > 1) {
-    if (MemMigrationEvents.size()) {
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
-                                       MemMigrationEvents.size(),
-                                       MemMigrationEvents.data()));
-    }
     for (auto &MemArg : hKernel->Args.MemObjArgs) {
       enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
                                            CuStream);
+      if (MemArg.AccessFlags &
+          (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
+        MemArg.Mem->setLastQueueWritingToMemObj(hQueue);
+      }
     }
   }
 
@@ -499,20 +468,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     UR_CHECK_ERROR(RetImplEvent->start());
   }
 
-  // Once event has been started we can unlock MemoryMigrationMutex
-  if (phEvent && hQueue->getContext()->Devices.size() > 1) {
-    for (auto &MemArg : hKernel->Args.MemObjArgs) {
-      // Telling the ur_mem_handle_t that it will need to wait on this kernel
-      // if it has been written to
-      if (MemArg.AccessFlags &
-          (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
-        MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
-      }
-    }
-    // We can release the MemoryMigrationMutexes now
-    MemMigrationLocks.clear();
-  }
-
   auto &ArgIndices = hKernel->getArgIndices();
   UR_CHECK_ERROR(cuLaunchKernel(
       CuFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
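
The hunks above replace the per-mem-object event list and migration mutexes with a single recorded writer queue. The new setLastQueueWritingToMemObj member is defined outside this diff; below is a minimal sketch of what such a member could look like, assuming the handle still guards this bookkeeping with a mutex and that queue reference counting is handled elsewhere (everything not named in the diff is an assumption):

#include <mutex>

struct ur_queue_handle_t_; // opaque queue type, defined elsewhere in the adapter
using ur_queue_handle_t = ur_queue_handle_t_ *;

// Sketch only: the real ur_mem_handle_t_ is declared in the adapter's headers,
// which this diff does not show.
struct ur_mem_handle_t_ {
  // Queue that last enqueued a write to this memory object; nullptr if the
  // object has never been written to.
  ur_queue_handle_t LastQueueWritingToMemObj{nullptr};
  std::mutex MemoryMigrationMutex; // assumed to still guard this field

  void setLastQueueWritingToMemObj(ur_queue_handle_t hQueue) {
    std::lock_guard<std::mutex> Guard{MemoryMigrationMutex};
    LastQueueWritingToMemObj = hQueue;
  }
};

The key behavioral shift is that the writer is now recorded unconditionally at enqueue time, instead of being attached to a ur_event_handle_t only when the caller requested an event.
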
@@ -605,36 +560,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     }
   }
 
-  std::vector<ur_event_handle_t> MemMigrationEvents;
-  std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
-
-  // phEventWaitList only contains events that are handed to UR by the SYCL
-  // runtime. However since UR handles memory dependencies within a context
-  // we may need to add more events to our dependent events list if the UR
-  // context contains multiple devices
-  if (hQueue->getContext()->Devices.size() > 1) {
-    MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
-    for (auto &MemArg : hKernel->Args.MemObjArgs) {
-      bool PushBack = false;
-      if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
-          MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
-                                           MemDepEvent)) {
-        MemMigrationEvents.push_back(MemDepEvent);
-        PushBack = true;
-      }
-      if ((MemArg.AccessFlags &
-           (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
-          PushBack) {
-        if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
-                         [MemArg](auto &Lock) {
-                           return Lock.first == MemArg.Mem;
-                         }) == MemMigrationLocks.end())
-          MemMigrationLocks.emplace_back(
-              std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
-      }
-    }
-  }
-
   // Early exit for zero size kernel
   if (*pGlobalWorkSize == 0) {
     return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
@@ -672,14 +597,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
 
   // For memory migration across devices in the same context
   if (hQueue->getContext()->Devices.size() > 1) {
-    if (MemMigrationEvents.size()) {
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
-                                       MemMigrationEvents.size(),
-                                       MemMigrationEvents.data()));
-    }
     for (auto &MemArg : hKernel->Args.MemObjArgs) {
       enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
                                            CuStream);
+      if (MemArg.AccessFlags &
+          (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
+        MemArg.Mem->setLastQueueWritingToMemObj(hQueue);
+      }
     }
   }
 
@@ -690,20 +614,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     UR_CHECK_ERROR(RetImplEvent->start());
   }
 
-  // Once event has been started we can unlock MemoryMigrationMutex
-  if (phEvent && hQueue->getContext()->Devices.size() > 1) {
-    for (auto &MemArg : hKernel->Args.MemObjArgs) {
-      // Telling the ur_mem_handle_t that it will need to wait on this kernel
-      // if it has been written to
-      if (MemArg.AccessFlags &
-          (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
-        MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
-      }
-    }
-    // We can release the MemoryMigrationMutexes now
-    MemMigrationLocks.clear();
-  }
-
   auto &ArgIndices = hKernel->getArgIndices();
 
   CUlaunchConfig launch_config;
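
enqueueMigrateMemoryToDeviceIfNeeded is an existing helper in this file that both launch paths call before recording the writer; its body is not part of this diff. The pattern it implements is roughly the sketch below, where the per-device validity bookkeeping is a stand-in for whatever the adapter really tracks (all names except the CUDA types are hypothetical):

#include <cuda.h>
#include <unordered_map>

// Hypothetical mem-object sketch; the adapter's real per-device validity
// tracking lives on ur_mem_handle_t_ and is not shown in this diff.
struct MemObjSketch {
  std::unordered_map<int, bool> ValidOnDevice; // device ordinal -> up to date?

  void enqueueCopyTo(int Device, CUstream Stream) {
    // A real implementation would enqueue cuMemcpyAsync / cuMemcpyPeerAsync
    // on Stream here; elided in this sketch.
    (void)Device;
    (void)Stream;
  }
};

static void migrateMemoryToDeviceIfNeeded(MemObjSketch &Mem, int Device,
                                          CUstream Stream) {
  // Only enqueue a copy when the target device does not already hold the
  // newest data, then mark the data valid there.
  if (!Mem.ValidOnDevice[Device]) {
    Mem.enqueueCopyTo(Device, Stream);
    Mem.ValidOnDevice[Device] = true;
  }
}
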
@@ -824,28 +734,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
     ur_event_handle_t *phEvent) {
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
 
-  ur_lock MemoryMigrationLock{hBuffer->MemoryMigrationMutex};
-  auto Device = hQueue->getDevice();
-  ScopedContext Active(Device);
-  CUstream Stream = hQueue->getNextTransferStream();
-
   try {
     // Note that this entry point may be called on a queue that may not be the
     // last queue to write to the MemBuffer, meaning we must perform the copy
     // from a different device
-    if (hBuffer->LastEventWritingToMemObj &&
-        hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
-            hQueue->getDevice()) {
-      hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
-      Device = hQueue->getDevice();
-      ScopedContext Active(Device);
-      Stream = CUstream{0}; // Default stream for different device
-      // We may have to wait for an event on another queue if it is the last
-      // event writing to mem obj
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
-                                       &hBuffer->LastEventWritingToMemObj));
+    if (hBuffer->LastQueueWritingToMemObj &&
+        hBuffer->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
+      hQueue = hBuffer->LastQueueWritingToMemObj;
     }
 
+    auto Device = hQueue->getDevice();
+    ScopedContext Active(Device);
+    CUstream Stream = hQueue->getNextTransferStream();
+
     UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
                                      phEventWaitList));
 
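
Every read entry point in this patch repeats the redirect seen above. Condensed, the new control flow looks like the helper below; the helper itself is illustrative (the adapter writes this logic out inline at each call site), while the member names match the diff:

// Illustrative helper; assumes the handle types used in the diff.
template <typename MemHandleT, typename QueueHandleT>
static QueueHandleT pickSourceQueue(MemHandleT *hMem, QueueHandleT hQueue) {
  // If the most recent write came from a queue on a different device, read
  // from that queue: its device holds the up-to-date allocation, and the
  // transfer is ordered after the write within that queue's streams.
  if (hMem->LastQueueWritingToMemObj &&
      hMem->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice())
    return hMem->LastQueueWritingToMemObj;
  return hQueue;
}

This is also what makes the old CUstream{0} fallback unnecessary: since the copy is now enqueued on the writing queue itself, getNextTransferStream() can be used on either device.
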
@@ -890,7 +791,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
   CUdeviceptr DevPtr =
       std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
+  hBuffer->setLastQueueWritingToMemObj(hQueue);
 
   try {
     ScopedContext Active(hQueue->getDevice());
@@ -920,7 +821,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
 
     if (phEvent) {
       *phEvent = RetImplEvent.release();
-      hBuffer->setLastEventWritingToMemObj(*phEvent);
     }
   } catch (ur_result_t Err) {
     return Err;
@@ -1060,7 +960,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
   UR_ASSERT(size + offset <= std::get<BufferMem>(hBuffer->Mem).getSize(),
             UR_RESULT_ERROR_INVALID_SIZE);
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
+  hBuffer->setLastQueueWritingToMemObj(hQueue);
 
   try {
     ScopedContext Active(hQueue->getDevice());
@@ -1107,7 +1007,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
     if (phEvent) {
       UR_CHECK_ERROR(RetImplEvent->record());
       *phEvent = RetImplEvent.release();
-      hBuffer->setLastEventWritingToMemObj(*phEvent);
     }
   } catch (ur_result_t Err) {
     return Err;
@@ -1215,28 +1114,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
 
   UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
 
-  ur_lock MemoryMigrationLock{hImage->MemoryMigrationMutex};
-  auto Device = hQueue->getDevice();
-  CUstream Stream = hQueue->getNextTransferStream();
-
   try {
     // Note that this entry point may be called on a queue that may not be the
     // last queue to write to the Image, meaning we must perform the copy
     // from a different device
-    if (hImage->LastEventWritingToMemObj &&
-        hImage->LastEventWritingToMemObj->getQueue()->getDevice() !=
-            hQueue->getDevice()) {
-      hQueue = hImage->LastEventWritingToMemObj->getQueue();
-      Device = hQueue->getDevice();
-      ScopedContext Active(Device);
-      Stream = CUstream{0}; // Default stream for different device
-      // We may have to wait for an event on another queue if it is the last
-      // event writing to mem obj
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
-                                       &hImage->LastEventWritingToMemObj));
+    if (hImage->LastQueueWritingToMemObj &&
+        hImage->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
+      hQueue = hImage->LastQueueWritingToMemObj;
     }
 
+    auto Device = hQueue->getDevice();
     ScopedContext Active(Device);
+    CUstream Stream = hQueue->getNextTransferStream();
+
     UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
                                      phEventWaitList));
 
@@ -1839,28 +1729,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
   UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).Size,
             UR_RESULT_ERROR_INVALID_SIZE);
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemoryMigrationLock{hBuffer->MemoryMigrationMutex};
-  auto Device = hQueue->getDevice();
-  ScopedContext Active(Device);
-  CUstream Stream = hQueue->getNextTransferStream();
 
   try {
     // Note that this entry point may be called on a queue that may not be the
     // last queue to write to the MemBuffer, meaning we must perform the copy
     // from a different device
-    if (hBuffer->LastEventWritingToMemObj &&
-        hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
-            hQueue->getDevice()) {
-      hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
-      Device = hQueue->getDevice();
-      ScopedContext Active(Device);
-      Stream = CUstream{0}; // Default stream for different device
-      // We may have to wait for an event on another queue if it is the last
-      // event writing to mem obj
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
-                                       &hBuffer->LastEventWritingToMemObj));
+    if (hBuffer->LastQueueWritingToMemObj &&
+        hBuffer->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
+      hQueue = hBuffer->LastQueueWritingToMemObj;
     }
 
+    auto Device = hQueue->getDevice();
+    ScopedContext Active(Device);
+    CUstream Stream = hQueue->getNextTransferStream();
+
     UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
                                      phEventWaitList));
 
@@ -1905,7 +1787,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
   CUdeviceptr DevPtr =
       std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
+  hBuffer->setLastQueueWritingToMemObj(hQueue);
 
   try {
     ScopedContext Active(hQueue->getDevice());
@@ -1933,7 +1815,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
 
     if (phEvent) {
       *phEvent = RetImplEvent.release();
-      hBuffer->setLastEventWritingToMemObj(*phEvent);
     }
   } catch (ur_result_t Err) {
     return Err;
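
Taken together, the write entry points now record the writing queue eagerly, even when the caller passes a null phEvent, and the read entry points consult that record to pick the source device. A usage sketch of the cross-device flow this enables, using the public UR entry points (queue and buffer creation elided; Q0 and Q1 are assumed to target different devices of the same context):

#include <ur_api.h>
#include <vector>

void crossDeviceRoundTrip(ur_queue_handle_t Q0, ur_queue_handle_t Q1,
                          ur_mem_handle_t Buf) {
  std::vector<int> Src(1024, 42), Dst(1024, 0);

  // Records Q0 as Buf's last writing queue before the copy is enqueued, even
  // though no completion event is requested (phEvent == nullptr).
  urEnqueueMemBufferWrite(Q0, Buf, /*blockingWrite=*/true, /*offset=*/0,
                          Src.size() * sizeof(int), Src.data(), 0, nullptr,
                          nullptr);

  // The adapter sees that Q0's device holds the newest data and performs the
  // transfer from Q0's device rather than Q1's.
  urEnqueueMemBufferRead(Q1, Buf, /*blockingRead=*/true, /*offset=*/0,
                         Dst.size() * sizeof(int), Dst.data(), 0, nullptr,
                         nullptr);
}
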