@@ -499,7 +499,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
   }
 
   // Preconditions
-  UR_ASSERT(hQueue->getContext() == hKernel->getContext(),
+  UR_ASSERT(hQueue->getDevice() == hKernel->getProgram()->getDevice(),
             UR_RESULT_ERROR_INVALID_KERNEL);
   UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
   UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
@@ -538,6 +538,38 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     }
   }
 
+  std::vector<ur_event_handle_t> DepEvents(
+      phEventWaitList, phEventWaitList + numEventsInWaitList);
+  std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
+
+  // phEventWaitList only contains events that are handed to UR by the SYCL
+  // runtime. However, since UR handles memory dependencies within a context,
+  // we may need to add more events to our dependent events list if the UR
+  // context contains multiple devices.
+  if (hQueue->getContext()->Devices.size() > 1) {
+    MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
+    for (auto &MemArg : hKernel->Args.MemObjArgs) {
+      bool PushBack = false;
+      if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
+          MemDepEvent && std::find(DepEvents.begin(), DepEvents.end(),
+                                   MemDepEvent) == DepEvents.end()) {
+        DepEvents.push_back(MemDepEvent);
+        PushBack = true;
+      }
+      if ((MemArg.AccessFlags &
+           (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
+          PushBack) {
+        if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
+                         [MemArg](auto &Lock) {
+                           return Lock.first == MemArg.Mem;
+                         }) == MemMigrationLocks.end())
+          MemMigrationLocks.emplace_back(
+              std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
+      }
+    }
+  }
+
+  // Early exit for zero size kernel
   if (*pGlobalWorkSize == 0) {
     return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
                                           phEventWaitList, phEvent);
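
(Annotation, not part of the patch.) The added block above does two things before anything is enqueued: it extends the wait list with the last-writer event of each buffer argument, and it takes a per-buffer lock that stays held until the launch's own event exists. A minimal standalone sketch of that gathering pattern follows; Event, Buffer and gatherDeps are placeholder names for illustration, not UR types.

// Illustrative sketch only, not part of the patch. Placeholder types stand in
// for ur_event_handle_t / ur_mem_handle_t.
#include <algorithm>
#include <mutex>
#include <vector>

struct Event {};
struct Buffer {
  Event *LastWriter = nullptr; // event of the last kernel that wrote this buffer
  std::mutex MigrationMutex;   // guards cross-device migration of the allocation
  bool KernelMayWrite = false; // this launch accesses the buffer as writable
};

// Start from the caller-supplied wait list, append each buffer's last-writer
// event if it is not already present, and lock any buffer that may move or be
// written so no other queue migrates it in the meantime.
std::vector<Event *>
gatherDeps(std::vector<Event *> Deps, const std::vector<Buffer *> &Args,
           std::vector<std::unique_lock<std::mutex>> &Locks) {
  for (Buffer *B : Args) {
    bool Added = false;
    if (B->LastWriter &&
        std::find(Deps.begin(), Deps.end(), B->LastWriter) == Deps.end()) {
      Deps.push_back(B->LastWriter);
      Added = true;
    }
    if (B->KernelMayWrite || Added)
      Locks.emplace_back(B->MigrationMutex);
  }
  return Deps;
}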
@@ -549,26 +581,37 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
   size_t BlocksPerGrid[3] = {1u, 1u, 1u};
 
   uint32_t LocalSize = hKernel->getLocalSize();
-  ur_result_t Result = UR_RESULT_SUCCESS;
   CUfunction CuFunc = hKernel->get();
 
-  Result = setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                           nullptr, pGlobalWorkSize, pLocalWorkSize, hKernel,
-                           CuFunc, ThreadsPerBlock, BlocksPerGrid);
-  if (Result != UR_RESULT_SUCCESS) {
-    return Result;
-  }
+  // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
+  // using the standard UR_CHECK_ERROR
+  if (ur_result_t Ret =
+          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
+                          nullptr, pGlobalWorkSize, pLocalWorkSize,
+                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+      Ret != UR_RESULT_SUCCESS)
+    return Ret;
 
   try {
     std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
 
+    ScopedContext Active(hQueue->getDevice());
     uint32_t StreamToken;
     ur_stream_guard_ Guard;
     CUstream CuStream = hQueue->getNextComputeStream(
         numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
 
-    Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
-                               phEventWaitList);
+    if (DepEvents.size()) {
+      UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, DepEvents.size(),
+                                       DepEvents.data()));
+    }
+
+    // For memory migration across devices in the same context
+    if (hQueue->getContext()->Devices.size() > 1) {
+      for (auto &MemArg : hKernel->Args.MemObjArgs) {
+        migrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice());
+      }
+    }
 
     if (phEvent) {
       RetImplEvent =
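
(Annotation, not part of the patch.) The replacement above drops the accumulate-into-Result style for an early return, since, as the comment notes, an adapter-specific error cannot go through the standard UR_CHECK_ERROR and is instead handed straight back to the caller. A toy version of the C++17 shape, with hypothetical names (doWork, launch) and ur_result_t faked as an int alias:

// Toy illustration of the early-return pattern; not UR code. The real
// ur_result_t is an enum, faked here for brevity.
using ur_result_t = int;
constexpr ur_result_t UR_RESULT_SUCCESS = 0;

ur_result_t doWork() { return UR_RESULT_SUCCESS; } // stand-in for setKernelParams

ur_result_t launch() {
  // The if-statement initializer keeps Ret scoped to the check and propagates
  // any failure, including adapter-specific codes, to the caller immediately.
  if (ur_result_t Ret = doWork(); Ret != UR_RESULT_SUCCESS)
    return Ret;
  // ... enqueue work only on the success path ...
  return UR_RESULT_SUCCESS;
}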
@@ -577,6 +620,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
       UR_CHECK_ERROR(RetImplEvent->start());
     }
 
+    // Once the event has been started we can unlock MemoryMigrationMutex
+    if (hQueue->getContext()->Devices.size() > 1) {
+      for (auto &MemArg : hKernel->Args.MemObjArgs) {
+        // Telling the ur_mem_handle_t that it will need to wait on this kernel
+        // if it has been written to
+        if (phEvent && (MemArg.AccessFlags &
+                        (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
+          MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
+        }
+      }
+      // We can release the MemoryMigrationMutexes now
+      MemMigrationLocks.clear();
+    }
+
     auto &ArgIndices = hKernel->getArgIndices();
 
     CUlaunchConfig launch_config;
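
(Annotation, not part of the patch.) The ordering above is the other half of the dependency handshake: only once RetImplEvent has been started is it recorded on each buffer the kernel may write (when the caller asked for an event), and only then are the migration locks dropped, so the next launch that takes a buffer's lock already sees the event it must wait on. A few lines in the same placeholder style as the earlier sketch:

// Illustrative counterpart to the gatherDeps sketch above; placeholder types,
// not UR API.
#include <vector>

struct Event {};
struct Buffer {
  Event *LastWriter = nullptr;
  bool KernelMayWrite = false;
};

// Once the launch's event exists, remember it on every buffer this kernel may
// have written; a later launch picks it up as an extra dependency before the
// locks protecting the buffer are released.
void recordWriter(const std::vector<Buffer *> &Args, Event *LaunchEvent) {
  for (Buffer *B : Args)
    if (LaunchEvent && B->KernelMayWrite)
      B->LastWriter = LaunchEvent;
}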
@@ -605,12 +662,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     }
 
   } catch (ur_result_t Err) {
-    Result = Err;
+    return Err;
   }
-  return Result;
+  return UR_RESULT_SUCCESS;
 }
 
-
 /// Set parameters for general 3D memory copy.
 /// If the source and/or destination is on the device, SrcPtr and/or DstPtr
 /// must be a pointer to a CUdeviceptr