@@ -414,36 +414,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
   UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
 
-  std::vector<ur_event_handle_t> MemMigrationEvents;
-  std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
-
-  // phEventWaitList only contains events that are handed to UR by the SYCL
-  // runtime. However since UR handles memory dependencies within a context
-  // we may need to add more events to our dependent events list if the UR
-  // context contains multiple devices
-  if (hQueue->getContext()->Devices.size() > 1) {
-    MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
-    for (auto &MemArg : hKernel->Args.MemObjArgs) {
-      bool PushBack = false;
-      if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
-          MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
-                                           MemDepEvent)) {
-        MemMigrationEvents.push_back(MemDepEvent);
-        PushBack = true;
-      }
-      if ((MemArg.AccessFlags &
-           (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
-          PushBack) {
-        if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
-                         [MemArg](auto &Lock) {
-                           return Lock.first == MemArg.Mem;
-                         }) == MemMigrationLocks.end())
-          MemMigrationLocks.emplace_back(
-              std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
-      }
-    }
-  }
-
   // Early exit for zero size kernel
   if (*pGlobalWorkSize == 0) {
     return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
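The block removed above is the per-kernel dependency harvesting: for every buffer argument it picked up the buffer's last writing event and took a per-buffer migration lock. For orientation, here is a sketch of the tracking state on `ur_mem_handle_t_` implied by the two versions of this code. The member names come straight from this diff, but the declarations themselves are an assumption, since the handle type is defined outside these hunks in the adapter's memory header.

```cpp
// Sketch only: approximate shape of the tracking members on ur_mem_handle_t_,
// inferred from their usage in this diff. Not the actual declarations.
struct ur_queue_handle_t_;
using ur_queue_handle_t = ur_queue_handle_t_ *;

struct ur_mem_handle_t_ {
  // Old scheme (removed): the last event that wrote to the allocation, plus
  // a mutex held from dependency collection until the kernel's start event.
  // ur_event_handle_t LastEventWritingToMemObj = nullptr;
  // ur_mutex MemoryMigrationMutex;

  // New scheme: only the last queue that wrote to the allocation. Readers on
  // another device re-route their copy through this queue, and no lock is
  // held across the enqueue. The real setter presumably also manages the
  // queue's reference count; that detail is elided here.
  ur_queue_handle_t LastQueueWritingToMemObj = nullptr;
  void setLastQueueWritingToMemObj(ur_queue_handle_t hQueue) {
    LastQueueWritingToMemObj = hQueue;
  }
};
```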
@@ -481,14 +451,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 
     // For memory migration across devices in the same context
     if (hQueue->getContext()->Devices.size() > 1) {
-      if (MemMigrationEvents.size()) {
-        UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
-                                         MemMigrationEvents.size(),
-                                         MemMigrationEvents.data()));
-      }
       for (auto &MemArg : hKernel->Args.MemObjArgs) {
         enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
                                              CuStream);
+        if (MemArg.AccessFlags &
+            (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
+          MemArg.Mem->setLastQueueWritingToMemObj(hQueue);
+        }
       }
     }
 
@@ -499,20 +468,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
       UR_CHECK_ERROR(RetImplEvent->start());
     }
 
-    // Once event has been started we can unlock MemoryMigrationMutex
-    if (phEvent && hQueue->getContext()->Devices.size() > 1) {
-      for (auto &MemArg : hKernel->Args.MemObjArgs) {
-        // Telling the ur_mem_handle_t that it will need to wait on this kernel
-        // if it has been written to
-        if (MemArg.AccessFlags &
-            (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
-          MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
-        }
-      }
-      // We can release the MemoryMigrationMutexes now
-      MemMigrationLocks.clear();
-    }
-
     auto &ArgIndices = hKernel->getArgIndices();
     UR_CHECK_ERROR(cuLaunchKernel(
         CuFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
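With the per-event bookkeeping gone, the launch path no longer records an event onto each written buffer. The ordering property the new scheme appears to lean on is CUDA stream ordering: commands enqueued on one `CUstream` execute in submission order, so a later copy issued on the writer's queue is ordered after the write without an explicit event. Below is a standalone driver-API illustration of that guarantee; it is not code from this patch, and UR queues can multiplex several streams, so the adapter's full story also involves the queue's own internal synchronization.

```cpp
// Standalone illustration: operations enqueued on one CUstream execute in
// submission order, so a copy enqueued after a write on the same stream
// always observes that write, with no event between them.
#include <cuda.h>
#include <cassert>
#include <vector>

int main() {
  cuInit(0);
  CUdevice Dev;
  cuDeviceGet(&Dev, 0);
  CUcontext Ctx;
  cuCtxCreate(&Ctx, 0, Dev);

  CUstream Stream;
  cuStreamCreate(&Stream, CU_STREAM_DEFAULT);

  const size_t Size = 1024;
  CUdeviceptr DevPtr;
  cuMemAlloc(&DevPtr, Size);

  std::vector<unsigned char> Src(Size, 42), Dst(Size, 0);
  // "Write" then "read" on the same stream: stream order alone guarantees
  // the device-to-host copy sees the 42s written by the host-to-device copy.
  cuMemcpyHtoDAsync(DevPtr, Src.data(), Size, Stream);
  cuMemcpyDtoHAsync(Dst.data(), DevPtr, Size, Stream);
  cuStreamSynchronize(Stream);
  assert(Dst[0] == 42);

  cuMemFree(DevPtr);
  cuStreamDestroy(Stream);
  cuCtxDestroy(Ctx);
  return 0;
}
```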
@@ -605,36 +560,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     }
   }
 
-  std::vector<ur_event_handle_t> MemMigrationEvents;
-  std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
-
-  // phEventWaitList only contains events that are handed to UR by the SYCL
-  // runtime. However since UR handles memory dependencies within a context
-  // we may need to add more events to our dependent events list if the UR
-  // context contains multiple devices
-  if (hQueue->getContext()->Devices.size() > 1) {
-    MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
-    for (auto &MemArg : hKernel->Args.MemObjArgs) {
-      bool PushBack = false;
-      if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
-          MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
-                                           MemDepEvent)) {
-        MemMigrationEvents.push_back(MemDepEvent);
-        PushBack = true;
-      }
-      if ((MemArg.AccessFlags &
-           (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
-          PushBack) {
-        if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
-                         [MemArg](auto &Lock) {
-                           return Lock.first == MemArg.Mem;
-                         }) == MemMigrationLocks.end())
-          MemMigrationLocks.emplace_back(
-              std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
-      }
-    }
-  }
-
   // Early exit for zero size kernel
   if (*pGlobalWorkSize == 0) {
     return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
@@ -672,14 +597,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
 
     // For memory migration across devices in the same context
     if (hQueue->getContext()->Devices.size() > 1) {
-      if (MemMigrationEvents.size()) {
-        UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
-                                         MemMigrationEvents.size(),
-                                         MemMigrationEvents.data()));
-      }
       for (auto &MemArg : hKernel->Args.MemObjArgs) {
         enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
                                              CuStream);
+        if (MemArg.AccessFlags &
+            (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
+          MemArg.Mem->setLastQueueWritingToMemObj(hQueue);
+        }
       }
     }
 
@@ -690,20 +614,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
       UR_CHECK_ERROR(RetImplEvent->start());
     }
 
-    // Once event has been started we can unlock MemoryMigrationMutex
-    if (phEvent && hQueue->getContext()->Devices.size() > 1) {
-      for (auto &MemArg : hKernel->Args.MemObjArgs) {
-        // Telling the ur_mem_handle_t that it will need to wait on this kernel
-        // if it has been written to
-        if (MemArg.AccessFlags &
-            (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
-          MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
-        }
-      }
-      // We can release the MemoryMigrationMutexes now
-      MemMigrationLocks.clear();
-    }
-
     auto &ArgIndices = hKernel->getArgIndices();
 
     CUlaunchConfig launch_config;
@@ -824,28 +734,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
     ur_event_handle_t *phEvent) {
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
 
-  ur_lock MemoryMigrationLock{hBuffer->MemoryMigrationMutex};
-  auto Device = hQueue->getDevice();
-  ScopedContext Active(Device);
-  CUstream Stream = hQueue->getNextTransferStream();
-
   try {
     // Note that this entry point may be called on a queue that may not be the
     // last queue to write to the MemBuffer, meaning we must perform the copy
     // from a different device
-    if (hBuffer->LastEventWritingToMemObj &&
-        hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
-            hQueue->getDevice()) {
-      hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
-      Device = hQueue->getDevice();
-      ScopedContext Active(Device);
-      Stream = CUstream{0}; // Default stream for different device
-      // We may have to wait for an event on another queue if it is the last
-      // event writing to mem obj
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
-                                       &hBuffer->LastEventWritingToMemObj));
+    if (hBuffer->LastQueueWritingToMemObj &&
+        hBuffer->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
+      hQueue = hBuffer->LastQueueWritingToMemObj;
     }
 
+    auto Device = hQueue->getDevice();
+    ScopedContext Active(Device);
+    CUstream Stream = hQueue->getNextTransferStream();
+
     UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
                                      phEventWaitList));
 
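The hunk above is the first of three read-style entry points (`urEnqueueMemBufferReadRect`, `urEnqueueMemImageRead`, `urEnqueueMemBufferRead`) that gain the same shape: consult the buffer's last writing queue and, if it belongs to a different device, issue the copy from that queue instead. A condensed, self-contained sketch of the shared pattern with stand-in types; the real entry points inline this logic, and no such helper exists in the patch.

```cpp
// Stand-in types; the real ur_queue_handle_t_ / ur_mem_handle_t_ are the
// adapter's own classes. This only captures the control flow used above.
struct Device {};
struct Queue {
  Device *Dev;
  Device *getDevice() { return Dev; }
};
struct MemObj {
  Queue *LastQueueWritingToMemObj = nullptr;
};

// If the last writer is a queue on another device, the copy is issued from
// that queue so stream order places it after the write; otherwise the
// caller's queue is used unchanged.
Queue *pickQueueForRead(MemObj *Mem, Queue *SubmitQueue) {
  if (Mem->LastQueueWritingToMemObj &&
      Mem->LastQueueWritingToMemObj->getDevice() != SubmitQueue->getDevice())
    return Mem->LastQueueWritingToMemObj;
  return SubmitQueue;
}
```

Note also that in all three entry points the `ScopedContext` and transfer stream are now created after this decision, from whichever queue was chosen; previously they were set up front and patched inside the `if`.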
@@ -890,7 +791,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
   CUdeviceptr DevPtr =
       std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
+  hBuffer->setLastQueueWritingToMemObj(hQueue);
 
   try {
     ScopedContext Active(hQueue->getDevice());
@@ -920,7 +821,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
 
     if (phEvent) {
       *phEvent = RetImplEvent.release();
-      hBuffer->setLastEventWritingToMemObj(*phEvent);
     }
   } catch (ur_result_t Err) {
     return Err;
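The write-style entry points (`urEnqueueMemBufferWriteRect` above, and `urEnqueueMemBufferFill` / `urEnqueueMemBufferWrite` below) are the mirror image: the migration lock and the post-hoc `setLastEventWritingToMemObj(*phEvent)` call both disappear, replaced by a single up-front annotation of the submitting queue. Using the same stand-in types as the previous sketch:

```cpp
// Mirror of pickQueueForRead: a write simply records its queue before the
// work is enqueued. No mutex is held and no event is stamped afterwards;
// stream order on the writing queue supplies the write->read ordering when
// a later read on another device redirects to it.
void noteLastWriter(MemObj *Mem, Queue *SubmitQueue) {
  Mem->LastQueueWritingToMemObj = SubmitQueue;
}
```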
@@ -1060,7 +960,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
   UR_ASSERT(size + offset <= std::get<BufferMem>(hBuffer->Mem).getSize(),
             UR_RESULT_ERROR_INVALID_SIZE);
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
+  hBuffer->setLastQueueWritingToMemObj(hQueue);
 
   try {
     ScopedContext Active(hQueue->getDevice());
@@ -1107,7 +1007,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
     if (phEvent) {
       UR_CHECK_ERROR(RetImplEvent->record());
       *phEvent = RetImplEvent.release();
-      hBuffer->setLastEventWritingToMemObj(*phEvent);
     }
   } catch (ur_result_t Err) {
     return Err;
@@ -1215,28 +1114,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
 
   UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
 
-  ur_lock MemoryMigrationLock{hImage->MemoryMigrationMutex};
-  auto Device = hQueue->getDevice();
-  CUstream Stream = hQueue->getNextTransferStream();
-
   try {
     // Note that this entry point may be called on a queue that may not be the
     // last queue to write to the Image, meaning we must perform the copy
     // from a different device
-    if (hImage->LastEventWritingToMemObj &&
-        hImage->LastEventWritingToMemObj->getQueue()->getDevice() !=
-            hQueue->getDevice()) {
-      hQueue = hImage->LastEventWritingToMemObj->getQueue();
-      Device = hQueue->getDevice();
-      ScopedContext Active(Device);
-      Stream = CUstream{0}; // Default stream for different device
-      // We may have to wait for an event on another queue if it is the last
-      // event writing to mem obj
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
-                                       &hImage->LastEventWritingToMemObj));
+    if (hImage->LastQueueWritingToMemObj &&
+        hImage->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
+      hQueue = hImage->LastQueueWritingToMemObj;
     }
 
+    auto Device = hQueue->getDevice();
     ScopedContext Active(Device);
+    CUstream Stream = hQueue->getNextTransferStream();
+
     UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
                                      phEventWaitList));
 
@@ -1839,28 +1729,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
   UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).Size,
             UR_RESULT_ERROR_INVALID_SIZE);
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemoryMigrationLock{hBuffer->MemoryMigrationMutex};
-  auto Device = hQueue->getDevice();
-  ScopedContext Active(Device);
-  CUstream Stream = hQueue->getNextTransferStream();
 
   try {
     // Note that this entry point may be called on a queue that may not be the
     // last queue to write to the MemBuffer, meaning we must perform the copy
     // from a different device
-    if (hBuffer->LastEventWritingToMemObj &&
-        hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
-            hQueue->getDevice()) {
-      hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
-      Device = hQueue->getDevice();
-      ScopedContext Active(Device);
-      Stream = CUstream{0}; // Default stream for different device
-      // We may have to wait for an event on another queue if it is the last
-      // event writing to mem obj
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, 1,
-                                       &hBuffer->LastEventWritingToMemObj));
+    if (hBuffer->LastQueueWritingToMemObj &&
+        hBuffer->LastQueueWritingToMemObj->getDevice() != hQueue->getDevice()) {
+      hQueue = hBuffer->LastQueueWritingToMemObj;
     }
 
+    auto Device = hQueue->getDevice();
+    ScopedContext Active(Device);
+    CUstream Stream = hQueue->getNextTransferStream();
+
     UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
                                      phEventWaitList));
 
@@ -1905,7 +1787,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
   CUdeviceptr DevPtr =
       std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
   std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
-  ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
+  hBuffer->setLastQueueWritingToMemObj(hQueue);
 
   try {
     ScopedContext Active(hQueue->getDevice());
@@ -1933,7 +1815,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
 
     if (phEvent) {
       *phEvent = RetImplEvent.release();
-      hBuffer->setLastEventWritingToMemObj(*phEvent);
     }
   } catch (ur_result_t Err) {
     return Err;
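End to end, the scheme after this patch: a write stamps its queue on the buffer, and a read submitted on another device within the same context is re-issued from the stamped queue, as the read paths above implement it. A hedged usage sketch from the application side; the handle setup is assumed to exist elsewhere, and error handling is elided.

```cpp
#include <cstddef>
#include <ur_api.h>

// Sketch: cross-device producer/consumer on one buffer in a multi-device
// context. Assumed to be created elsewhere: a ur_context_handle_t spanning
// two devices, hQueueA on device A, hQueueB on device B, and hBuffer of
// `size` bytes in that context.
void producerConsumer(ur_queue_handle_t hQueueA, ur_queue_handle_t hQueueB,
                      ur_mem_handle_t hBuffer, size_t size, const void *src,
                      void *dst) {
  // Write on device A: the adapter records hQueueA as the buffer's last
  // writing queue.
  urEnqueueMemBufferWrite(hQueueA, hBuffer, /*blockingWrite=*/false,
                          /*offset=*/0, size, src,
                          /*numEventsInWaitList=*/0, nullptr, nullptr);

  // Read on device B: since hQueueA is the last writer on another device,
  // the adapter re-issues this copy from hQueueA, ordering it after the
  // write without the caller threading an event between the two queues.
  urEnqueueMemBufferRead(hQueueB, hBuffer, /*blockingRead=*/true,
                         /*offset=*/0, size, dst,
                         /*numEventsInWaitList=*/0, nullptr, nullptr);
}
```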