@@ -414,8 +414,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
414414 UR_ASSERT (workDim > 0 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
415415 UR_ASSERT (workDim < 4 , UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
416416
417- std::vector<ur_event_handle_t > DepEvents (
418- phEventWaitList, phEventWaitList + numEventsInWaitList);
417+ std::vector<ur_event_handle_t > MemMigrationEvents;
419418 std::vector<std::pair<ur_mem_handle_t , ur_lock>> MemMigrationLocks;
420419
421420 // phEventWaitList only contains events that are handed to UR by the SYCL
@@ -427,9 +426,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
427426 for (auto &MemArg : hKernel->Args .MemObjArgs ) {
428427 bool PushBack = false ;
429428 if (auto MemDepEvent = MemArg.Mem ->LastEventWritingToMemObj ;
430- MemDepEvent && std::find (DepEvents. begin (), DepEvents. end () ,
431- MemDepEvent) == DepEvents. end ( )) {
432- DepEvents .push_back (MemDepEvent);
429+ MemDepEvent && ! listContainsElem (numEventsInWaitList, phEventWaitList ,
430+ MemDepEvent )) {
431+ MemMigrationEvents .push_back (MemDepEvent);
433432 PushBack = true ;
434433 }
435434 if ((MemArg.AccessFlags &
@@ -477,19 +476,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
477476 CUstream CuStream = hQueue->getNextComputeStream (
478477 numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
479478
480- if (DepEvents.size ()) {
481- UR_CHECK_ERROR (enqueueEventsWait (hQueue, CuStream, DepEvents.size (),
482- DepEvents.data ()));
483- }
479+ UR_CHECK_ERROR (enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
480+ phEventWaitList));
484481
485482 // For memory migration across devices in the same context
486483 if (hQueue->getContext ()->Devices .size () > 1 ) {
484+ if (MemMigrationEvents.size ()) {
485+ UR_CHECK_ERROR (enqueueEventsWait (hQueue, CuStream,
486+ MemMigrationEvents.size (),
487+ MemMigrationEvents.data ()));
488+ }
487489 for (auto &MemArg : hKernel->Args .MemObjArgs ) {
488- migrateMemoryToDeviceIfNeeded (MemArg.Mem , hQueue->getDevice ());
490+ enqueueMigrateMemoryToDeviceIfNeeded (MemArg.Mem , hQueue->getDevice (),
491+ CuStream);
489492 }
490493 }
491494
492- if (phEvent) {
495+ if (phEvent || MemMigrationEvents. size () ) {
493496 RetImplEvent =
494497 std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
495498 UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
@@ -522,8 +525,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
522525 if (phEvent) {
523526 UR_CHECK_ERROR (RetImplEvent->record ());
524527 *phEvent = RetImplEvent.release ();
528+ } else if (MemMigrationEvents.size ()) {
529+ UR_CHECK_ERROR (RetImplEvent->record ());
530+ for (auto &MemArg : hKernel->Args .MemObjArgs ) {
531+ // If no event is passed to entry point, we still need to have an event
532+ // if ur_mem_handle_t s are used. Here we give ownership of the event
533+ // to the ur_mem_handle_t
534+ if (MemArg.AccessFlags &
535+ (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
536+ MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.release ());
537+ }
538+ }
525539 }
526-
527540 } catch (ur_result_t Err) {
528541 return Err;
529542 }
@@ -603,8 +616,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
603616 }
604617 }
605618
606- std::vector<ur_event_handle_t > DepEvents (
607- phEventWaitList, phEventWaitList + numEventsInWaitList);
619+ std::vector<ur_event_handle_t > MemMigrationEvents;
608620 std::vector<std::pair<ur_mem_handle_t , ur_lock>> MemMigrationLocks;
609621
610622 // phEventWaitList only contains events that are handed to UR by the SYCL
@@ -616,9 +628,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
616628 for (auto &MemArg : hKernel->Args .MemObjArgs ) {
617629 bool PushBack = false ;
618630 if (auto MemDepEvent = MemArg.Mem ->LastEventWritingToMemObj ;
619- MemDepEvent && std::find (DepEvents. begin (), DepEvents. end () ,
620- MemDepEvent) == DepEvents. end ( )) {
621- DepEvents .push_back (MemDepEvent);
631+ MemDepEvent && ! listContainsElem (numEventsInWaitList, phEventWaitList ,
632+ MemDepEvent )) {
633+ MemMigrationEvents .push_back (MemDepEvent);
622634 PushBack = true ;
623635 }
624636 if ((MemArg.AccessFlags &
@@ -666,19 +678,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
666678 CUstream CuStream = hQueue->getNextComputeStream (
667679 numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
668680
669- if (DepEvents.size ()) {
670- UR_CHECK_ERROR (enqueueEventsWait (hQueue, CuStream, DepEvents.size (),
671- DepEvents.data ()));
672- }
681+ UR_CHECK_ERROR (enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
682+ phEventWaitList));
673683
674684 // For memory migration across devices in the same context
675685 if (hQueue->getContext ()->Devices .size () > 1 ) {
686+ if (MemMigrationEvents.size ()) {
687+ UR_CHECK_ERROR (enqueueEventsWait (hQueue, CuStream,
688+ MemMigrationEvents.size (),
689+ MemMigrationEvents.data ()));
690+ }
676691 for (auto &MemArg : hKernel->Args .MemObjArgs ) {
677- migrateMemoryToDeviceIfNeeded (MemArg.Mem , hQueue->getDevice ());
692+ enqueueMigrateMemoryToDeviceIfNeeded (MemArg.Mem , hQueue->getDevice (),
693+ CuStream);
678694 }
679695 }
680696
681- if (phEvent) {
697+ if (phEvent || MemMigrationEvents. size () ) {
682698 RetImplEvent =
683699 std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
684700 UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
@@ -724,6 +740,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
724740 if (phEvent) {
725741 UR_CHECK_ERROR (RetImplEvent->record ());
726742 *phEvent = RetImplEvent.release ();
743+ } else if (MemMigrationEvents.size ()) {
744+ UR_CHECK_ERROR (RetImplEvent->record ());
745+ for (auto &MemArg : hKernel->Args .MemObjArgs ) {
746+ // If no event is passed to entry point, we still need to have an event
747+ // if ur_mem_handle_t s are used. Here we give ownership of the event
748+ // to the ur_mem_handle_t
749+ if (MemArg.AccessFlags &
750+ (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
751+ MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.release ());
752+ }
753+ }
727754 }
728755
729756 } catch (ur_result_t Err) {
0 commit comments