@@ -160,42 +160,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
160
160
UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
161
161
UR_ASSERT (hBuffer->isBuffer (), UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
162
162
163
- ur_result_t Result = UR_RESULT_SUCCESS;
164
- std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
163
+ ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
165
164
166
165
try {
167
166
ScopedContext Active (hQueue->getDevice ());
168
167
hipStream_t HIPStream = hQueue->getNextTransferStream ();
169
168
UR_CHECK_ERROR (enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
170
169
phEventWaitList));
171
170
172
- if (phEvent) {
173
- RetImplEvent =
174
- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
175
- UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
176
- UR_CHECK_ERROR (RetImplEvent->start ());
177
- }
171
+ // With multi dev ctx we have no choice but to record this event
172
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent =
173
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
174
+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
175
+ UR_CHECK_ERROR (RetImplEvent->start ());
178
176
179
177
UR_CHECK_ERROR (
180
178
hipMemcpyHtoDAsync (std::get<BufferMem>(hBuffer->Mem )
181
179
.getPtrWithOffset (hQueue->getDevice (), offset),
182
180
const_cast <void *>(pSrc), size, HIPStream));
183
181
184
- if (phEvent) {
185
- UR_CHECK_ERROR (RetImplEvent->record ());
186
- }
182
+ UR_CHECK_ERROR (RetImplEvent->record ());
187
183
188
184
if (blockingWrite) {
189
185
UR_CHECK_ERROR (hipStreamSynchronize (HIPStream));
190
186
}
191
187
192
188
if (phEvent) {
193
189
*phEvent = RetImplEvent.release ();
190
+ hBuffer->setLastEventWritingToMemObj (*phEvent);
191
+ } else {
192
+ hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
194
193
}
195
194
} catch (ur_result_t Err) {
196
- Result = Err;
195
+ return Err;
197
196
}
198
- return Result ;
197
+ return UR_RESULT_SUCCESS ;
199
198
}
200
199
201
200
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead (
@@ -656,44 +655,43 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
656
655
size_t hostRowPitch, size_t hostSlicePitch, void *pSrc,
657
656
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
658
657
ur_event_handle_t *phEvent) {
659
- ur_result_t Result = UR_RESULT_SUCCESS;
660
658
void *DevPtr = std::get<BufferMem>(hBuffer->Mem ).getVoid (hQueue->getDevice ());
661
- std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
659
+
660
+ ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
662
661
663
662
try {
664
663
ScopedContext Active (hQueue->getDevice ());
665
664
hipStream_t HIPStream = hQueue->getNextTransferStream ();
666
- Result = enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
667
- phEventWaitList);
665
+ UR_CHECK_ERROR ( enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
666
+ phEventWaitList) );
668
667
669
- if (phEvent) {
670
- RetImplEvent =
671
- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
672
- UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, HIPStream));
673
- UR_CHECK_ERROR (RetImplEvent->start ());
674
- }
668
+ // With multi dev ctx we have no choice but to record this event
669
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent =
670
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
671
+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
672
+ UR_CHECK_ERROR (RetImplEvent->start ());
675
673
676
- Result = commonEnqueueMemBufferCopyRect (
674
+ UR_CHECK_ERROR ( commonEnqueueMemBufferCopyRect (
677
675
HIPStream, region, pSrc, hipMemoryTypeHost, hostOrigin, hostRowPitch,
678
676
hostSlicePitch, &DevPtr, hipMemoryTypeDevice, bufferOrigin,
679
- bufferRowPitch, bufferSlicePitch);
677
+ bufferRowPitch, bufferSlicePitch)) ;
680
678
681
- if (phEvent) {
682
- UR_CHECK_ERROR (RetImplEvent->record ());
683
- }
679
+ UR_CHECK_ERROR (RetImplEvent->record ());
684
680
685
681
if (blockingWrite) {
686
682
UR_CHECK_ERROR (hipStreamSynchronize (HIPStream));
687
683
}
688
684
689
685
if (phEvent) {
690
686
*phEvent = RetImplEvent.release ();
687
+ hBuffer->setLastEventWritingToMemObj (*phEvent);
688
+ } else {
689
+ hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
691
690
}
692
-
693
691
} catch (ur_result_t Err) {
694
- Result = Err;
692
+ return Err;
695
693
}
696
- return Result ;
694
+ return UR_RESULT_SUCCESS ;
697
695
}
698
696
699
697
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy (
@@ -885,24 +883,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
885
883
std::ignore = PatternIsValid;
886
884
std::ignore = PatternSizeIsValid;
887
885
888
- std::unique_ptr<ur_event_handle_t_> RetImplEvent{ nullptr };
886
+ ur_lock MemMigrationLock{hBuffer-> MemoryMigrationMutex };
889
887
890
888
try {
891
889
ScopedContext Active (hQueue->getDevice ());
892
890
893
891
auto Stream = hQueue->getNextTransferStream ();
894
- ur_result_t Result = UR_RESULT_SUCCESS;
895
892
if (phEventWaitList) {
896
- Result = enqueueEventsWait (hQueue, Stream, numEventsInWaitList,
897
- phEventWaitList);
893
+ UR_CHECK_ERROR ( enqueueEventsWait (hQueue, Stream, numEventsInWaitList,
894
+ phEventWaitList) );
898
895
}
899
896
900
- if (phEvent) {
901
- RetImplEvent =
902
- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
903
- UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
904
- UR_CHECK_ERROR (RetImplEvent->start ());
905
- }
897
+ // With multi dev ctx we have no choice but to record this event
898
+ std::unique_ptr<ur_event_handle_t_> RetImplEvent =
899
+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
900
+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, Stream));
901
+ UR_CHECK_ERROR (RetImplEvent->start ());
906
902
907
903
auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
908
904
.getPtrWithOffset (hQueue->getDevice (), offset);
@@ -927,23 +923,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
927
923
}
928
924
929
925
default : {
930
- Result = commonMemSetLargePattern (Stream, patternSize, size, pPattern ,
931
- DstDevice);
926
+ UR_CHECK_ERROR ( commonMemSetLargePattern (Stream, patternSize, size,
927
+ pPattern, DstDevice) );
932
928
break ;
933
929
}
934
930
}
935
931
932
+ UR_CHECK_ERROR (RetImplEvent->record ());
936
933
if (phEvent) {
937
- UR_CHECK_ERROR (RetImplEvent->record ());
938
934
*phEvent = RetImplEvent.release ();
935
+ hBuffer->setLastEventWritingToMemObj (*phEvent);
936
+ } else {
937
+ hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
939
938
}
940
-
941
- return Result;
942
939
} catch (ur_result_t Err) {
943
940
return Err;
944
941
} catch (...) {
945
942
return UR_RESULT_ERROR_UNKNOWN;
946
943
}
944
+ return UR_RESULT_SUCCESS;
947
945
}
948
946
949
947
// / General ND memory copy operation for images (where N > 1).
0 commit comments