@@ -160,42 +160,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
160160 UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
161161 UR_ASSERT (hBuffer->isBuffer (), UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
162162
163- ur_result_t Result = UR_RESULT_SUCCESS;
164- std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
163+ ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
165164
166165 try {
167166 ScopedContext Active (hQueue->getDevice ());
168167 hipStream_t HIPStream = hQueue->getNextTransferStream ();
169168 UR_CHECK_ERROR (enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
170169 phEventWaitList));
171170
172- if (phEvent) {
173- RetImplEvent =
174- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
175- UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
176- UR_CHECK_ERROR (RetImplEvent->start ());
177- }
171+ // With multi dev ctx we have no choice but to record this event
172+ std::unique_ptr<ur_event_handle_t_> RetImplEvent =
173+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
174+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
175+ UR_CHECK_ERROR (RetImplEvent->start ());
178176
179177 UR_CHECK_ERROR (
180178 hipMemcpyHtoDAsync (std::get<BufferMem>(hBuffer->Mem )
181179 .getPtrWithOffset (hQueue->getDevice (), offset),
182180 const_cast <void *>(pSrc), size, HIPStream));
183181
184- if (phEvent) {
185- UR_CHECK_ERROR (RetImplEvent->record ());
186- }
182+ UR_CHECK_ERROR (RetImplEvent->record ());
187183
188184 if (blockingWrite) {
189185 UR_CHECK_ERROR (hipStreamSynchronize (HIPStream));
190186 }
191187
192188 if (phEvent) {
193189 *phEvent = RetImplEvent.release ();
190+ hBuffer->setLastEventWritingToMemObj (*phEvent);
191+ } else {
192+ hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
194193 }
195194 } catch (ur_result_t Err) {
196- Result = Err;
195+ return Err;
197196 }
198- return Result ;
197+ return UR_RESULT_SUCCESS ;
199198}
200199
201200UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead (
@@ -656,44 +655,43 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
656655 size_t hostRowPitch, size_t hostSlicePitch, void *pSrc,
657656 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
658657 ur_event_handle_t *phEvent) {
659- ur_result_t Result = UR_RESULT_SUCCESS;
660658 void *DevPtr = std::get<BufferMem>(hBuffer->Mem ).getVoid (hQueue->getDevice ());
661- std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
659+
660+ ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex };
662661
663662 try {
664663 ScopedContext Active (hQueue->getDevice ());
665664 hipStream_t HIPStream = hQueue->getNextTransferStream ();
666- Result = enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
667- phEventWaitList);
665+ UR_CHECK_ERROR ( enqueueEventsWait (hQueue, HIPStream, numEventsInWaitList,
666+ phEventWaitList) );
668667
669- if (phEvent) {
670- RetImplEvent =
671- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
672- UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, HIPStream));
673- UR_CHECK_ERROR (RetImplEvent->start ());
674- }
668+ // With multi dev ctx we have no choice but to record this event
669+ std::unique_ptr<ur_event_handle_t_> RetImplEvent =
670+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
671+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
672+ UR_CHECK_ERROR (RetImplEvent->start ());
675673
676- Result = commonEnqueueMemBufferCopyRect (
674+ UR_CHECK_ERROR ( commonEnqueueMemBufferCopyRect (
677675 HIPStream, region, pSrc, hipMemoryTypeHost, hostOrigin, hostRowPitch,
678676 hostSlicePitch, &DevPtr, hipMemoryTypeDevice, bufferOrigin,
679- bufferRowPitch, bufferSlicePitch);
677+ bufferRowPitch, bufferSlicePitch)) ;
680678
681- if (phEvent) {
682- UR_CHECK_ERROR (RetImplEvent->record ());
683- }
679+ UR_CHECK_ERROR (RetImplEvent->record ());
684680
685681 if (blockingWrite) {
686682 UR_CHECK_ERROR (hipStreamSynchronize (HIPStream));
687683 }
688684
689685 if (phEvent) {
690686 *phEvent = RetImplEvent.release ();
687+ hBuffer->setLastEventWritingToMemObj (*phEvent);
688+ } else {
689+ hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
691690 }
692-
693691 } catch (ur_result_t Err) {
694- Result = Err;
692+ return Err;
695693 }
696- return Result ;
694+ return UR_RESULT_SUCCESS ;
697695}
698696
699697UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy (
@@ -885,24 +883,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
885883 std::ignore = PatternIsValid;
886884 std::ignore = PatternSizeIsValid;
887885
888- std::unique_ptr<ur_event_handle_t_> RetImplEvent{ nullptr };
886+ ur_lock MemMigrationLock{hBuffer-> MemoryMigrationMutex };
889887
890888 try {
891889 ScopedContext Active (hQueue->getDevice ());
892890
893891 auto Stream = hQueue->getNextTransferStream ();
894- ur_result_t Result = UR_RESULT_SUCCESS;
895892 if (phEventWaitList) {
896- Result = enqueueEventsWait (hQueue, Stream, numEventsInWaitList,
897- phEventWaitList);
893+ UR_CHECK_ERROR ( enqueueEventsWait (hQueue, Stream, numEventsInWaitList,
894+ phEventWaitList) );
898895 }
899896
900- if (phEvent) {
901- RetImplEvent =
902- std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
903- UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
904- UR_CHECK_ERROR (RetImplEvent->start ());
905- }
897+ // With multi dev ctx we have no choice but to record this event
898+ std::unique_ptr<ur_event_handle_t_> RetImplEvent =
899+ std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
900+ UR_COMMAND_MEM_BUFFER_WRITE, hQueue, Stream));
901+ UR_CHECK_ERROR (RetImplEvent->start ());
906902
907903 auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
908904 .getPtrWithOffset (hQueue->getDevice (), offset);
@@ -927,23 +923,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
927923 }
928924
929925 default : {
930- Result = commonMemSetLargePattern (Stream, patternSize, size, pPattern ,
931- DstDevice);
926+ UR_CHECK_ERROR ( commonMemSetLargePattern (Stream, patternSize, size,
927+ pPattern, DstDevice) );
932928 break ;
933929 }
934930 }
935931
932+ UR_CHECK_ERROR (RetImplEvent->record ());
936933 if (phEvent) {
937- UR_CHECK_ERROR (RetImplEvent->record ());
938934 *phEvent = RetImplEvent.release ();
935+ hBuffer->setLastEventWritingToMemObj (*phEvent);
936+ } else {
937+ hBuffer->setLastEventWritingToMemObj (RetImplEvent.release ());
939938 }
940-
941- return Result;
942939 } catch (ur_result_t Err) {
943940 return Err;
944941 } catch (...) {
945942 return UR_RESULT_ERROR_UNKNOWN;
946943 }
944+ return UR_RESULT_SUCCESS;
947945}
948946
949947// / General ND memory copy operation for images (where N > 1).
0 commit comments