@@ -414,8 +414,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
   UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);

-  std::vector<ur_event_handle_t> DepEvents(
-      phEventWaitList, phEventWaitList + numEventsInWaitList);
+  std::vector<ur_event_handle_t> MemMigrationEvents;
   std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;

   // phEventWaitList only contains events that are handed to UR by the SYCL
@@ -427,9 +426,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     for (auto &MemArg : hKernel->Args.MemObjArgs) {
       bool PushBack = false;
       if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
-          MemDepEvent && std::find(DepEvents.begin(), DepEvents.end(),
-                                   MemDepEvent) == DepEvents.end()) {
-        DepEvents.push_back(MemDepEvent);
+          MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
+                                           MemDepEvent)) {
+        MemMigrationEvents.push_back(MemDepEvent);
         PushBack = true;
       }
       if ((MemArg.AccessFlags &
@@ -477,19 +476,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     CUstream CuStream = hQueue->getNextComputeStream(
         numEventsInWaitList, phEventWaitList, Guard, &StreamToken);

-    if (DepEvents.size()) {
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, DepEvents.size(),
-                                       DepEvents.data()));
-    }
+    UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
+                                     phEventWaitList));

     // For memory migration across devices in the same context
     if (hQueue->getContext()->Devices.size() > 1) {
+      if (MemMigrationEvents.size()) {
+        UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
+                                         MemMigrationEvents.size(),
+                                         MemMigrationEvents.data()));
+      }
       for (auto &MemArg : hKernel->Args.MemObjArgs) {
-        migrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice());
+        enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
+                                             CuStream);
       }
     }

-    if (phEvent) {
+    if (phEvent || MemMigrationEvents.size()) {
       RetImplEvent =
           std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
               UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
@@ -522,8 +525,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     if (phEvent) {
       UR_CHECK_ERROR(RetImplEvent->record());
       *phEvent = RetImplEvent.release();
+    } else if (MemMigrationEvents.size()) {
+      UR_CHECK_ERROR(RetImplEvent->record());
+      for (auto &MemArg : hKernel->Args.MemObjArgs) {
+        // If no event is passed to entry point, we still need to have an event
+        // if ur_mem_handle_t s are used. Here we give ownership of the event
+        // to the ur_mem_handle_t
+        if (MemArg.AccessFlags &
+            (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
+          MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.release());
+        }
+      }
     }
-
   } catch (ur_result_t Err) {
     return Err;
   }
@@ -603,8 +616,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     }
   }

-  std::vector<ur_event_handle_t> DepEvents(
-      phEventWaitList, phEventWaitList + numEventsInWaitList);
+  std::vector<ur_event_handle_t> MemMigrationEvents;
   std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;

   // phEventWaitList only contains events that are handed to UR by the SYCL
@@ -616,9 +628,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     for (auto &MemArg : hKernel->Args.MemObjArgs) {
       bool PushBack = false;
       if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
-          MemDepEvent && std::find(DepEvents.begin(), DepEvents.end(),
-                                   MemDepEvent) == DepEvents.end()) {
-        DepEvents.push_back(MemDepEvent);
+          MemDepEvent && !listContainsElem(numEventsInWaitList, phEventWaitList,
+                                           MemDepEvent)) {
+        MemMigrationEvents.push_back(MemDepEvent);
         PushBack = true;
       }
       if ((MemArg.AccessFlags &
@@ -666,19 +678,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     CUstream CuStream = hQueue->getNextComputeStream(
         numEventsInWaitList, phEventWaitList, Guard, &StreamToken);

-    if (DepEvents.size()) {
-      UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, DepEvents.size(),
-                                       DepEvents.data()));
-    }
+    UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
+                                     phEventWaitList));

     // For memory migration across devices in the same context
     if (hQueue->getContext()->Devices.size() > 1) {
+      if (MemMigrationEvents.size()) {
+        UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream,
+                                         MemMigrationEvents.size(),
+                                         MemMigrationEvents.data()));
+      }
       for (auto &MemArg : hKernel->Args.MemObjArgs) {
-        migrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice());
+        enqueueMigrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice(),
+                                             CuStream);
       }
     }

-    if (phEvent) {
+    if (phEvent || MemMigrationEvents.size()) {
       RetImplEvent =
           std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
               UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
@@ -724,6 +740,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     if (phEvent) {
       UR_CHECK_ERROR(RetImplEvent->record());
       *phEvent = RetImplEvent.release();
+    } else if (MemMigrationEvents.size()) {
+      UR_CHECK_ERROR(RetImplEvent->record());
+      for (auto &MemArg : hKernel->Args.MemObjArgs) {
+        // If no event is passed to entry point, we still need to have an event
+        // if ur_mem_handle_t s are used. Here we give ownership of the event
+        // to the ur_mem_handle_t
+        if (MemArg.AccessFlags &
+            (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
+          MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.release());
+        }
+      }
     }

   } catch (ur_result_t Err) {
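
Note: the hunks above replace the std::find duplicate check with a listContainsElem helper whose definition is not part of this diff, and migrateMemoryToDeviceIfNeeded with enqueueMigrateMemoryToDeviceIfNeeded, which additionally takes the compute stream, presumably so the migration copies are enqueued on that stream rather than performed eagerly. The sketch below is only a guess at what the helper might look like, with its name and arguments taken from the call sites; it assumes the UR headers and <algorithm> are included and is not the adapter's actual implementation.

// Hypothetical sketch only; not the definition used by the adapter.
// Returns true if Elem appears among the first NumElems entries of List,
// treating a null or empty list as "does not contain".
static bool listContainsElem(uint32_t NumElems, const ur_event_handle_t *List,
                             ur_event_handle_t Elem) {
  if (!List || NumElems == 0)
    return false;
  return std::find(List, List + NumElems, Elem) != List + NumElems;
}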
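With this change, a launch that passes phEvent == nullptr can still produce a native event when buffer arguments are written: the event is recorded and its ownership is handed to each written ur_mem_handle_t via setLastEventWritingToMemObj, so a later launch on another device in the same context can wait on it before migrating the buffer. A hedged caller-side sketch follows; the queue, kernel, and buffer setup are assumed to exist, and the call shape follows the Unified Runtime urEnqueueKernelLaunch entry point.

// Caller-side usage sketch, assuming Queue and Kernel already exist and the
// kernel writes to one of its buffer arguments.
size_t GlobalSize[1] = {1024};
ur_result_t Res = urEnqueueKernelLaunch(
    Queue, Kernel, /*workDim=*/1,
    /*pGlobalWorkOffset=*/nullptr, GlobalSize, /*pLocalWorkSize=*/nullptr,
    /*numEventsInWaitList=*/0, /*phEventWaitList=*/nullptr,
    /*phEvent=*/nullptr);
// No event was requested, but the adapter still records one internally and
// parks it on the written buffer, so a subsequent launch on a different
// device in the same context can order its migration against this write.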