@@ -492,20 +492,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
492492 }
493493 }
494494
495- if (phEvent || MemMigrationEvents. size () ) {
495+ if (phEvent) {
496496 RetImplEvent =
497497 std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
498498 UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
499499 UR_CHECK_ERROR (RetImplEvent->start ());
500500 }
501501
502502 // Once event has been started we can unlock MemoryMigrationMutex
503- if (hQueue->getContext ()->Devices .size () > 1 ) {
503+ if (phEvent && hQueue->getContext ()->Devices .size () > 1 ) {
504504 for (auto &MemArg : hKernel->Args .MemObjArgs ) {
505505 // Telling the ur_mem_handle_t that it will need to wait on this kernel
506506 // if it has been written to
507- if (phEvent && ( MemArg.AccessFlags &
508- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY) )) {
507+ if (MemArg.AccessFlags &
508+ (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
509509 MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.get ());
510510 }
511511 }
@@ -525,17 +525,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
525525 if (phEvent) {
526526 UR_CHECK_ERROR (RetImplEvent->record ());
527527 *phEvent = RetImplEvent.release ();
528- } else if (MemMigrationEvents.size ()) {
529- UR_CHECK_ERROR (RetImplEvent->record ());
530- for (auto &MemArg : hKernel->Args .MemObjArgs ) {
531- // If no event is passed to entry point, we still need to have an event
532- // if ur_mem_handle_t s are used. Here we give ownership of the event
533- // to the ur_mem_handle_t
534- if (MemArg.AccessFlags &
535- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
536- MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.release ());
537- }
538- }
539528 }
540529 } catch (ur_result_t Err) {
541530 return Err;
@@ -694,20 +683,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
694683 }
695684 }
696685
697- if (phEvent || MemMigrationEvents. size () ) {
686+ if (phEvent) {
698687 RetImplEvent =
699688 std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
700689 UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
701690 UR_CHECK_ERROR (RetImplEvent->start ());
702691 }
703692
704693 // Once event has been started we can unlock MemoryMigrationMutex
705- if (hQueue->getContext ()->Devices .size () > 1 ) {
694+ if (phEvent && hQueue->getContext ()->Devices .size () > 1 ) {
706695 for (auto &MemArg : hKernel->Args .MemObjArgs ) {
707696 // Telling the ur_mem_handle_t that it will need to wait on this kernel
708697 // if it has been written to
709- if (phEvent && ( MemArg.AccessFlags &
710- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY) )) {
698+ if (MemArg.AccessFlags &
699+ (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
711700 MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.get ());
712701 }
713702 }
@@ -740,19 +729,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
740729 if (phEvent) {
741730 UR_CHECK_ERROR (RetImplEvent->record ());
742731 *phEvent = RetImplEvent.release ();
743- } else if (MemMigrationEvents.size ()) {
744- UR_CHECK_ERROR (RetImplEvent->record ());
745- for (auto &MemArg : hKernel->Args .MemObjArgs ) {
746- // If no event is passed to entry point, we still need to have an event
747- // if ur_mem_handle_t s are used. Here we give ownership of the event
748- // to the ur_mem_handle_t
749- if (MemArg.AccessFlags &
750- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
751- MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.release ());
752- }
753- }
754732 }
755-
756733 } catch (ur_result_t Err) {
757734 return Err;
758735 }
@@ -912,6 +889,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
912889 ur_event_handle_t *phEvent) {
913890 CUdeviceptr DevPtr =
914891 std::get<BufferMem>(hBuffer->Mem ).getPtr (hQueue->getDevice ());
892+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
915893 ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
916894
917895 try {
@@ -920,18 +898,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
920898 UR_CHECK_ERROR (enqueueEventsWait (hQueue, cuStream, numEventsInWaitList,
921899 phEventWaitList));
922900
923- // With multi dev ctx we have no choice but to record this event
924- std::unique_ptr<ur_event_handle_t_> RetImplEvent =
925- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
926- UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
927- UR_CHECK_ERROR (RetImplEvent->start ());
901+ if (phEvent) {
902+ RetImplEvent =
903+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
904+ UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
905+ UR_CHECK_ERROR (RetImplEvent->start ());
906+ }
928907
929908 UR_CHECK_ERROR (commonEnqueueMemBufferCopyRect (
930909 cuStream, region, pSrc, CU_MEMORYTYPE_HOST, hostOrigin, hostRowPitch,
931910 hostSlicePitch, &DevPtr, CU_MEMORYTYPE_DEVICE, bufferOrigin,
932911 bufferRowPitch, bufferSlicePitch));
933912
934- UR_CHECK_ERROR (RetImplEvent->record ());
913+ if (phEvent) {
914+ UR_CHECK_ERROR (RetImplEvent->record ());
915+ }
935916
936917 if (blockingWrite) {
937918 UR_CHECK_ERROR (cuStreamSynchronize (cuStream));
@@ -940,10 +921,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
940921 if (phEvent) {
941922 *phEvent = RetImplEvent.release ();
942923 hBuffer->setLastEventWritingToMemObj (*phEvent);
943- } else {
944- hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
945924 }
946-
947925 } catch (ur_result_t Err) {
948926 return Err;
949927 }
@@ -1081,6 +1059,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
10811059 ur_event_handle_t *phEvent) {
10821060 UR_ASSERT (size + offset <= std::get<BufferMem>(hBuffer->Mem ).getSize (),
10831061 UR_RESULT_ERROR_INVALID_SIZE);
1062+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
10841063 ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
10851064
10861065 try {
@@ -1090,11 +1069,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
10901069 UR_CHECK_ERROR (enqueueEventsWait (hQueue, Stream, numEventsInWaitList,
10911070 phEventWaitList));
10921071
1093- // With multi dev ctx we have no choice but to record this event
1094- std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1095- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1096- UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1097- UR_CHECK_ERROR (RetImplEvent->start ());
1072+ if (phEvent) { // An event is only created when the caller supplied phEvent
1073+ RetImplEvent =
1074+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1075+ UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1076+ UR_CHECK_ERROR (RetImplEvent->start ());
1077+ }
10981078
10991079 auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
11001080 .getPtrWithOffset (hQueue->getDevice (), offset);
@@ -1124,13 +1104,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
11241104 }
11251105 }
11261106
1127- UR_CHECK_ERROR (RetImplEvent->record ());
11281107 if (phEvent) {
1108+ UR_CHECK_ERROR (RetImplEvent->record ());
11291109 *phEvent = RetImplEvent.release ();
11301110 hBuffer->setLastEventWritingToMemObj (*phEvent);
1131- } else {
1132- // Give buffer ownership if no event used
1133- hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
11341111 }
11351112 } catch (ur_result_t Err) {
11361113 return Err;
@@ -1925,27 +1902,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19251902 UR_ASSERT (offset + size <= std::get<BufferMem>(hBuffer->Mem ).Size ,
19261903 UR_RESULT_ERROR_INVALID_SIZE);
19271904
1928- ur_result_t Result = UR_RESULT_SUCCESS;
19291905 CUdeviceptr DevPtr =
19301906 std::get<BufferMem>(hBuffer->Mem ).getPtr (hQueue->getDevice ());
1907+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
19311908 ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
19321909
19331910 try {
19341911 ScopedContext Active (hQueue->getDevice ());
19351912 CUstream CuStream = hQueue->getNextTransferStream ();
19361913
1937- Result = enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
1938- phEventWaitList);
1914+ UR_CHECK_ERROR ( enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
1915+ phEventWaitList) );
19391916
1940- // With multi dev ctx we have no choice but to record this event
1941- std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1942- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1943- UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1944- UR_CHECK_ERROR (RetImplEvent->start ());
1917+ if (phEvent) {
1918+ RetImplEvent =
1919+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1920+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1921+ UR_CHECK_ERROR (RetImplEvent->start ());
1922+ }
19451923
19461924 UR_CHECK_ERROR (cuMemcpyHtoDAsync (DevPtr + offset, pSrc, size, CuStream));
19471925
1948- UR_CHECK_ERROR (RetImplEvent->record ());
1926+ if (phEvent) {
1927+ UR_CHECK_ERROR (RetImplEvent->record ());
1928+ }
19491929
19501930 if (blockingWrite) {
19511931 UR_CHECK_ERROR (cuStreamSynchronize (CuStream));
@@ -1954,14 +1934,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19541934 if (phEvent) {
19551935 *phEvent = RetImplEvent.release ();
19561936 hBuffer->setLastEventWritingToMemObj (*phEvent);
1957- } else {
1958- // Give buffer ownership if no event used
1959- hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
19601937 }
19611938 } catch (ur_result_t Err) {
1962- Result = Err;
1939+ return Err;
19631940 }
1964- return Result ;
1941+ return UR_RESULT_SUCCESS ;
19651942}
19661943
19671944UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
0 commit comments