@@ -492,20 +492,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
492
492
}
493
493
}
494
494
495
- if (phEvent || MemMigrationEvents. size () ) {
495
+ if (phEvent) {
496
496
RetImplEvent =
497
497
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
498
498
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
499
499
UR_CHECK_ERROR (RetImplEvent->start ());
500
500
}
501
501
502
502
// Once event has been started we can unlock MemoryMigrationMutex
503
- if (hQueue->getContext ()->Devices .size () > 1 ) {
503
+ if (phEvent && hQueue->getContext ()->Devices .size () > 1 ) {
504
504
for (auto &MemArg : hKernel->Args .MemObjArgs ) {
505
505
// Telling the ur_mem_handle_t that it will need to wait on this kernel
506
506
// if it has been written to
507
- if (phEvent && ( MemArg.AccessFlags &
508
- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY) )) {
507
+ if (MemArg.AccessFlags &
508
+ (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
509
509
MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.get ());
510
510
}
511
511
}
@@ -525,17 +525,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
525
525
if (phEvent) {
526
526
UR_CHECK_ERROR (RetImplEvent->record ());
527
527
*phEvent = RetImplEvent.release ();
528
- } else if (MemMigrationEvents.size ()) {
529
- UR_CHECK_ERROR (RetImplEvent->record ());
530
- for (auto &MemArg : hKernel->Args .MemObjArgs ) {
531
- // If no event is passed to entry point, we still need to have an event
532
- // if ur_mem_handle_t s are used. Here we give ownership of the event
533
- // to the ur_mem_handle_t
534
- if (MemArg.AccessFlags &
535
- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
536
- MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.release ());
537
- }
538
- }
539
528
}
540
529
} catch (ur_result_t Err) {
541
530
return Err;
@@ -694,20 +683,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
694
683
}
695
684
}
696
685
697
- if (phEvent || MemMigrationEvents. size () ) {
686
+ if (phEvent) {
698
687
RetImplEvent =
699
688
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
700
689
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
701
690
UR_CHECK_ERROR (RetImplEvent->start ());
702
691
}
703
692
704
693
// Once event has been started we can unlock MemoryMigrationMutex
705
- if (hQueue->getContext ()->Devices .size () > 1 ) {
694
+ if (phEvent && hQueue->getContext ()->Devices .size () > 1 ) {
706
695
for (auto &MemArg : hKernel->Args .MemObjArgs ) {
707
696
// Telling the ur_mem_handle_t that it will need to wait on this kernel
708
697
// if it has been written to
709
- if (phEvent && ( MemArg.AccessFlags &
710
- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY) )) {
698
+ if (MemArg.AccessFlags &
699
+ (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
711
700
MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.get ());
712
701
}
713
702
}
@@ -740,19 +729,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
740
729
if (phEvent) {
741
730
UR_CHECK_ERROR (RetImplEvent->record ());
742
731
*phEvent = RetImplEvent.release ();
743
- } else if (MemMigrationEvents.size ()) {
744
- UR_CHECK_ERROR (RetImplEvent->record ());
745
- for (auto &MemArg : hKernel->Args .MemObjArgs ) {
746
- // If no event is passed to entry point, we still need to have an event
747
- // if ur_mem_handle_t s are used. Here we give ownership of the event
748
- // to the ur_mem_handle_t
749
- if (MemArg.AccessFlags &
750
- (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
751
- MemArg.Mem ->setLastEventWritingToMemObj (RetImplEvent.release ());
752
- }
753
- }
754
732
}
755
-
756
733
} catch (ur_result_t Err) {
757
734
return Err;
758
735
}
@@ -912,6 +889,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
912
889
ur_event_handle_t *phEvent) {
913
890
CUdeviceptr DevPtr =
914
891
std::get<BufferMem>(hBuffer->Mem ).getPtr (hQueue->getDevice ());
892
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
915
893
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
916
894
917
895
try {
@@ -920,18 +898,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
920
898
UR_CHECK_ERROR (enqueueEventsWait (hQueue, cuStream, numEventsInWaitList,
921
899
phEventWaitList));
922
900
923
- // With multi dev ctx we have no choice but to record this event
924
- std::unique_ptr<ur_event_handle_t_> RetImplEvent =
925
- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
926
- UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
927
- UR_CHECK_ERROR (RetImplEvent->start ());
901
+ if (phEvent) {
902
+ RetImplEvent =
903
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
904
+ UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
905
+ UR_CHECK_ERROR (RetImplEvent->start ());
906
+ }
928
907
929
908
UR_CHECK_ERROR (commonEnqueueMemBufferCopyRect (
930
909
cuStream, region, pSrc, CU_MEMORYTYPE_HOST, hostOrigin, hostRowPitch,
931
910
hostSlicePitch, &DevPtr, CU_MEMORYTYPE_DEVICE, bufferOrigin,
932
911
bufferRowPitch, bufferSlicePitch));
933
912
934
- UR_CHECK_ERROR (RetImplEvent->record ());
913
+ if (phEvent) {
914
+ UR_CHECK_ERROR (RetImplEvent->record ());
915
+ }
935
916
936
917
if (blockingWrite) {
937
918
UR_CHECK_ERROR (cuStreamSynchronize (cuStream));
@@ -940,10 +921,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
940
921
if (phEvent) {
941
922
*phEvent = RetImplEvent.release ();
942
923
hBuffer->setLastEventWritingToMemObj (*phEvent);
943
- } else {
944
- hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
945
924
}
946
-
947
925
} catch (ur_result_t Err) {
948
926
return Err;
949
927
}
@@ -1081,6 +1059,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
1081
1059
ur_event_handle_t *phEvent) {
1082
1060
UR_ASSERT (size + offset <= std::get<BufferMem>(hBuffer->Mem ).getSize (),
1083
1061
UR_RESULT_ERROR_INVALID_SIZE);
1062
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
1084
1063
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
1085
1064
1086
1065
try {
@@ -1090,11 +1069,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
1090
1069
UR_CHECK_ERROR (enqueueEventsWait (hQueue, Stream, numEventsInWaitList,
1091
1070
phEventWaitList));
1092
1071
1093
- // With multi dev ctx we have no choice but to record this event
1094
- std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1095
- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1096
- UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1097
- UR_CHECK_ERROR (RetImplEvent->start ());
1072
+ if (phEvent) {
1073
+ RetImplEvent =
1074
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1075
+ UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1076
+ UR_CHECK_ERROR (RetImplEvent->start ());
1077
+ }
1098
1078
1099
1079
auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
1100
1080
.getPtrWithOffset (hQueue->getDevice (), offset);
@@ -1124,13 +1104,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
1124
1104
}
1125
1105
}
1126
1106
1127
- UR_CHECK_ERROR (RetImplEvent->record ());
1128
1107
if (phEvent) {
1108
+ UR_CHECK_ERROR (RetImplEvent->record ());
1129
1109
*phEvent = RetImplEvent.release ();
1130
1110
hBuffer->setLastEventWritingToMemObj (*phEvent);
1131
- } else {
1132
- // Give buffer ownership if no event used
1133
- hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
1134
1111
}
1135
1112
} catch (ur_result_t Err) {
1136
1113
return Err;
@@ -1925,27 +1902,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
1925
1902
UR_ASSERT (offset + size <= std::get<BufferMem>(hBuffer->Mem ).Size ,
1926
1903
UR_RESULT_ERROR_INVALID_SIZE);
1927
1904
1928
- ur_result_t Result = UR_RESULT_SUCCESS;
1929
1905
CUdeviceptr DevPtr =
1930
1906
std::get<BufferMem>(hBuffer->Mem ).getPtr (hQueue->getDevice ());
1907
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
1931
1908
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
1932
1909
1933
1910
try {
1934
1911
ScopedContext Active (hQueue->getDevice ());
1935
1912
CUstream CuStream = hQueue->getNextTransferStream ();
1936
1913
1937
- Result = enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
1938
- phEventWaitList);
1914
+ UR_CHECK_ERROR ( enqueueEventsWait (hQueue, CuStream, numEventsInWaitList,
1915
+ phEventWaitList) );
1939
1916
1940
- // With multi dev ctx we have no choice but to record this event
1941
- std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1942
- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1943
- UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1944
- UR_CHECK_ERROR (RetImplEvent->start ());
1917
+ if (phEvent) {
1918
+ RetImplEvent =
1919
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
1920
+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1921
+ UR_CHECK_ERROR (RetImplEvent->start ());
1922
+ }
1945
1923
1946
1924
UR_CHECK_ERROR (cuMemcpyHtoDAsync (DevPtr + offset, pSrc, size, CuStream));
1947
1925
1948
- UR_CHECK_ERROR (RetImplEvent->record ());
1926
+ if (phEvent) {
1927
+ UR_CHECK_ERROR (RetImplEvent->record ());
1928
+ }
1949
1929
1950
1930
if (blockingWrite) {
1951
1931
UR_CHECK_ERROR (cuStreamSynchronize (CuStream));
@@ -1954,14 +1934,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
1954
1934
if (phEvent) {
1955
1935
*phEvent = RetImplEvent.release ();
1956
1936
hBuffer->setLastEventWritingToMemObj (*phEvent);
1957
- } else {
1958
- // Give buffer ownership if no event used
1959
- hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
1960
1937
}
1961
1938
} catch (ur_result_t Err) {
1962
- Result = Err;
1939
+ return Err;
1963
1940
}
1964
- return Result ;
1941
+ return UR_RESULT_SUCCESS ;
1965
1942
}
1966
1943
1967
1944
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
0 commit comments