Skip to content

Commit 836ded1

Browse files
author
Hugh Delaney
committed
Make write and fill ops setLastEventWritingToMemObj
Write and fill entry points weren't setting LastEventWritingToMemObj, which made some UR tests fail.
1 parent 5c77ac8 commit 836ded1

File tree

2 files changed

+81
-81
lines changed

2 files changed

+81
-81
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -912,36 +912,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
912912
ur_event_handle_t *phEvent) {
913913
CUdeviceptr DevPtr =
914914
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
915-
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
915+
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
916916

917917
try {
918918
ScopedContext Active(hQueue->getDevice());
919919
CUstream cuStream = hQueue->getNextTransferStream();
920920
UR_CHECK_ERROR(enqueueEventsWait(hQueue, cuStream, numEventsInWaitList,
921921
phEventWaitList));
922922

923-
if (phEvent) {
924-
RetImplEvent =
925-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
926-
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
927-
UR_CHECK_ERROR(RetImplEvent->start());
928-
}
923+
// With multi dev ctx we have no choice but to record this event
924+
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
925+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
926+
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
927+
UR_CHECK_ERROR(RetImplEvent->start());
929928

930929
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
931930
cuStream, region, pSrc, CU_MEMORYTYPE_HOST, hostOrigin, hostRowPitch,
932931
hostSlicePitch, &DevPtr, CU_MEMORYTYPE_DEVICE, bufferOrigin,
933932
bufferRowPitch, bufferSlicePitch));
934933

935-
if (phEvent) {
936-
UR_CHECK_ERROR(RetImplEvent->record());
937-
}
934+
UR_CHECK_ERROR(RetImplEvent->record());
938935

939936
if (blockingWrite) {
940937
UR_CHECK_ERROR(cuStreamSynchronize(cuStream));
941938
}
942939

943940
if (phEvent) {
944941
*phEvent = RetImplEvent.release();
942+
hBuffer->setLastEventWritingToMemObj(*phEvent);
943+
} else {
944+
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
945945
}
946946

947947
} catch (ur_result_t Err) {
@@ -1081,22 +1081,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
10811081
ur_event_handle_t *phEvent) {
10821082
UR_ASSERT(size + offset <= std::get<BufferMem>(hBuffer->Mem).getSize(),
10831083
UR_RESULT_ERROR_INVALID_SIZE);
1084-
1085-
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
1084+
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
10861085

10871086
try {
10881087
ScopedContext Active(hQueue->getDevice());
10891088

10901089
auto Stream = hQueue->getNextTransferStream();
1091-
ur_result_t Result =
1092-
enqueueEventsWait(hQueue, Stream, numEventsInWaitList, phEventWaitList);
1090+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
1091+
phEventWaitList));
10931092

1094-
if (phEvent) {
1095-
RetImplEvent =
1096-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1097-
UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1098-
UR_CHECK_ERROR(RetImplEvent->start());
1099-
}
1093+
// With multi dev ctx we have no choice but to record this event
1094+
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1095+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1096+
UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1097+
UR_CHECK_ERROR(RetImplEvent->start());
11001098

11011099
auto DstDevice = std::get<BufferMem>(hBuffer->Mem)
11021100
.getPtrWithOffset(hQueue->getDevice(), offset);
@@ -1120,23 +1118,26 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
11201118
break;
11211119
}
11221120
default: {
1123-
Result = commonMemSetLargePattern(Stream, patternSize, size, pPattern,
1124-
DstDevice);
1121+
UR_CHECK_ERROR(commonMemSetLargePattern(Stream, patternSize, size,
1122+
pPattern, DstDevice));
11251123
break;
11261124
}
11271125
}
11281126

1127+
UR_CHECK_ERROR(RetImplEvent->record());
11291128
if (phEvent) {
1130-
UR_CHECK_ERROR(RetImplEvent->record());
11311129
*phEvent = RetImplEvent.release();
1130+
hBuffer->setLastEventWritingToMemObj(*phEvent);
1131+
} else {
1132+
// Give buffer ownership if no event used
1133+
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
11321134
}
1133-
1134-
return Result;
11351135
} catch (ur_result_t Err) {
11361136
return Err;
11371137
} catch (...) {
11381138
return UR_RESULT_ERROR_UNKNOWN;
11391139
}
1140+
return UR_RESULT_SUCCESS;
11401141
}
11411142

11421143
static size_t imageElementByteSize(CUDA_ARRAY_DESCRIPTOR ArrayDesc) {
@@ -1927,7 +1928,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19271928
ur_result_t Result = UR_RESULT_SUCCESS;
19281929
CUdeviceptr DevPtr =
19291930
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
1930-
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
1931+
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
19311932

19321933
try {
19331934
ScopedContext Active(hQueue->getDevice());
@@ -1936,25 +1937,26 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19361937
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
19371938
phEventWaitList);
19381939

1939-
if (phEvent) {
1940-
RetImplEvent =
1941-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1942-
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1943-
UR_CHECK_ERROR(RetImplEvent->start());
1944-
}
1940+
// With multi dev ctx we have no choice but to record this event
1941+
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1942+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1943+
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1944+
UR_CHECK_ERROR(RetImplEvent->start());
19451945

19461946
UR_CHECK_ERROR(cuMemcpyHtoDAsync(DevPtr + offset, pSrc, size, CuStream));
19471947

1948-
if (phEvent) {
1949-
UR_CHECK_ERROR(RetImplEvent->record());
1950-
}
1948+
UR_CHECK_ERROR(RetImplEvent->record());
19511949

19521950
if (blockingWrite) {
19531951
UR_CHECK_ERROR(cuStreamSynchronize(CuStream));
19541952
}
19551953

19561954
if (phEvent) {
19571955
*phEvent = RetImplEvent.release();
1956+
hBuffer->setLastEventWritingToMemObj(*phEvent);
1957+
} else {
1958+
// Give buffer ownership if no event used
1959+
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
19581960
}
19591961
} catch (ur_result_t Err) {
19601962
Result = Err;

source/adapters/hip/enqueue.cpp

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -160,42 +160,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
160160
UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
161161
UR_ASSERT(hBuffer->isBuffer(), UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
162162

163-
ur_result_t Result = UR_RESULT_SUCCESS;
164-
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
163+
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
165164

166165
try {
167166
ScopedContext Active(hQueue->getDevice());
168167
hipStream_t HIPStream = hQueue->getNextTransferStream();
169168
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
170169
phEventWaitList));
171170

172-
if (phEvent) {
173-
RetImplEvent =
174-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
175-
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
176-
UR_CHECK_ERROR(RetImplEvent->start());
177-
}
171+
// With multi dev ctx we have no choice but to record this event
172+
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
173+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
174+
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
175+
UR_CHECK_ERROR(RetImplEvent->start());
178176

179177
UR_CHECK_ERROR(
180178
hipMemcpyHtoDAsync(std::get<BufferMem>(hBuffer->Mem)
181179
.getPtrWithOffset(hQueue->getDevice(), offset),
182180
const_cast<void *>(pSrc), size, HIPStream));
183181

184-
if (phEvent) {
185-
UR_CHECK_ERROR(RetImplEvent->record());
186-
}
182+
UR_CHECK_ERROR(RetImplEvent->record());
187183

188184
if (blockingWrite) {
189185
UR_CHECK_ERROR(hipStreamSynchronize(HIPStream));
190186
}
191187

192188
if (phEvent) {
193189
*phEvent = RetImplEvent.release();
190+
hBuffer->setLastEventWritingToMemObj(*phEvent);
191+
} else {
192+
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
194193
}
195194
} catch (ur_result_t Err) {
196-
Result = Err;
195+
return Err;
197196
}
198-
return Result;
197+
return UR_RESULT_SUCCESS;
199198
}
200199

201200
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
@@ -656,44 +655,43 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
656655
size_t hostRowPitch, size_t hostSlicePitch, void *pSrc,
657656
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
658657
ur_event_handle_t *phEvent) {
659-
ur_result_t Result = UR_RESULT_SUCCESS;
660658
void *DevPtr = std::get<BufferMem>(hBuffer->Mem).getVoid(hQueue->getDevice());
661-
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
659+
660+
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
662661

663662
try {
664663
ScopedContext Active(hQueue->getDevice());
665664
hipStream_t HIPStream = hQueue->getNextTransferStream();
666-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
667-
phEventWaitList);
665+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
666+
phEventWaitList));
668667

669-
if (phEvent) {
670-
RetImplEvent =
671-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
672-
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, HIPStream));
673-
UR_CHECK_ERROR(RetImplEvent->start());
674-
}
668+
// With multi dev ctx we have no choice but to record this event
669+
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
670+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
671+
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, HIPStream));
672+
UR_CHECK_ERROR(RetImplEvent->start());
675673

676-
Result = commonEnqueueMemBufferCopyRect(
674+
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
677675
HIPStream, region, pSrc, hipMemoryTypeHost, hostOrigin, hostRowPitch,
678676
hostSlicePitch, &DevPtr, hipMemoryTypeDevice, bufferOrigin,
679-
bufferRowPitch, bufferSlicePitch);
677+
bufferRowPitch, bufferSlicePitch));
680678

681-
if (phEvent) {
682-
UR_CHECK_ERROR(RetImplEvent->record());
683-
}
679+
UR_CHECK_ERROR(RetImplEvent->record());
684680

685681
if (blockingWrite) {
686682
UR_CHECK_ERROR(hipStreamSynchronize(HIPStream));
687683
}
688684

689685
if (phEvent) {
690686
*phEvent = RetImplEvent.release();
687+
hBuffer->setLastEventWritingToMemObj(*phEvent);
688+
} else {
689+
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
691690
}
692-
693691
} catch (ur_result_t Err) {
694-
Result = Err;
692+
return Err;
695693
}
696-
return Result;
694+
return UR_RESULT_SUCCESS;
697695
}
698696

699697
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
@@ -885,24 +883,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
885883
std::ignore = PatternIsValid;
886884
std::ignore = PatternSizeIsValid;
887885

888-
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
886+
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
889887

890888
try {
891889
ScopedContext Active(hQueue->getDevice());
892890

893891
auto Stream = hQueue->getNextTransferStream();
894-
ur_result_t Result = UR_RESULT_SUCCESS;
895892
if (phEventWaitList) {
896-
Result = enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
897-
phEventWaitList);
893+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
894+
phEventWaitList));
898895
}
899896

900-
if (phEvent) {
901-
RetImplEvent =
902-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
903-
UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
904-
UR_CHECK_ERROR(RetImplEvent->start());
905-
}
897+
// With multi dev ctx we have no choice but to record this event
898+
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
899+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
900+
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, Stream));
901+
UR_CHECK_ERROR(RetImplEvent->start());
906902

907903
auto DstDevice = std::get<BufferMem>(hBuffer->Mem)
908904
.getPtrWithOffset(hQueue->getDevice(), offset);
@@ -927,23 +923,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
927923
}
928924

929925
default: {
930-
Result = commonMemSetLargePattern(Stream, patternSize, size, pPattern,
931-
DstDevice);
926+
UR_CHECK_ERROR(commonMemSetLargePattern(Stream, patternSize, size,
927+
pPattern, DstDevice));
932928
break;
933929
}
934930
}
935931

932+
UR_CHECK_ERROR(RetImplEvent->record());
936933
if (phEvent) {
937-
UR_CHECK_ERROR(RetImplEvent->record());
938934
*phEvent = RetImplEvent.release();
935+
hBuffer->setLastEventWritingToMemObj(*phEvent);
936+
} else {
937+
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
939938
}
940-
941-
return Result;
942939
} catch (ur_result_t Err) {
943940
return Err;
944941
} catch (...) {
945942
return UR_RESULT_ERROR_UNKNOWN;
946943
}
944+
return UR_RESULT_SUCCESS;
947945
}
948946

949947
/// General ND memory copy operation for images (where N > 1).

0 commit comments

Comments
 (0)