Skip to content

Commit 7345640

Browse files
author
Hugh Delaney
committed
Make buffer migration only work if phEvent is non null
Buffer migration should only work across a context if the UR user passes in a non null event. Otherwise no event will be recorded. This simplifies code slightly.
1 parent 2e0ab53 commit 7345640

File tree

3 files changed

+96
-122
lines changed

3 files changed

+96
-122
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 40 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -492,20 +492,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
492492
}
493493
}
494494

495-
if (phEvent || MemMigrationEvents.size()) {
495+
if (phEvent) {
496496
RetImplEvent =
497497
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
498498
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
499499
UR_CHECK_ERROR(RetImplEvent->start());
500500
}
501501

502502
// Once event has been started we can unlock MemoryMigrationMutex
503-
if (hQueue->getContext()->Devices.size() > 1) {
503+
if (phEvent && hQueue->getContext()->Devices.size() > 1) {
504504
for (auto &MemArg : hKernel->Args.MemObjArgs) {
505505
// Telling the ur_mem_handle_t that it will need to wait on this kernel
506506
// if it has been written to
507-
if (phEvent && (MemArg.AccessFlags &
508-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
507+
if (MemArg.AccessFlags &
508+
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
509509
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
510510
}
511511
}
@@ -525,17 +525,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
525525
if (phEvent) {
526526
UR_CHECK_ERROR(RetImplEvent->record());
527527
*phEvent = RetImplEvent.release();
528-
} else if (MemMigrationEvents.size()) {
529-
UR_CHECK_ERROR(RetImplEvent->record());
530-
for (auto &MemArg : hKernel->Args.MemObjArgs) {
531-
// If no event is passed to entry point, we still need to have an event
532-
// if ur_mem_handle_t s are used. Here we give ownership of the event
533-
// to the ur_mem_handle_t
534-
if (MemArg.AccessFlags &
535-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
536-
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.release());
537-
}
538-
}
539528
}
540529
} catch (ur_result_t Err) {
541530
return Err;
@@ -694,20 +683,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
694683
}
695684
}
696685

697-
if (phEvent || MemMigrationEvents.size()) {
686+
if (phEvent) {
698687
RetImplEvent =
699688
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
700689
UR_COMMAND_KERNEL_LAUNCH, hQueue, CuStream, StreamToken));
701690
UR_CHECK_ERROR(RetImplEvent->start());
702691
}
703692

704693
// Once event has been started we can unlock MemoryMigrationMutex
705-
if (hQueue->getContext()->Devices.size() > 1) {
694+
if (phEvent && hQueue->getContext()->Devices.size() > 1) {
706695
for (auto &MemArg : hKernel->Args.MemObjArgs) {
707696
// Telling the ur_mem_handle_t that it will need to wait on this kernel
708697
// if it has been written to
709-
if (phEvent && (MemArg.AccessFlags &
710-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
698+
if (MemArg.AccessFlags &
699+
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
711700
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
712701
}
713702
}
@@ -740,19 +729,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
740729
if (phEvent) {
741730
UR_CHECK_ERROR(RetImplEvent->record());
742731
*phEvent = RetImplEvent.release();
743-
} else if (MemMigrationEvents.size()) {
744-
UR_CHECK_ERROR(RetImplEvent->record());
745-
for (auto &MemArg : hKernel->Args.MemObjArgs) {
746-
// If no event is passed to entry point, we still need to have an event
747-
// if ur_mem_handle_t s are used. Here we give ownership of the event
748-
// to the ur_mem_handle_t
749-
if (MemArg.AccessFlags &
750-
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) {
751-
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.release());
752-
}
753-
}
754732
}
755-
756733
} catch (ur_result_t Err) {
757734
return Err;
758735
}
@@ -912,6 +889,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
912889
ur_event_handle_t *phEvent) {
913890
CUdeviceptr DevPtr =
914891
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
892+
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
915893
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
916894

917895
try {
@@ -920,18 +898,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
920898
UR_CHECK_ERROR(enqueueEventsWait(hQueue, cuStream, numEventsInWaitList,
921899
phEventWaitList));
922900

923-
// With multi dev ctx we have no choice but to record this event
924-
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
925-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
926-
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
927-
UR_CHECK_ERROR(RetImplEvent->start());
901+
if (phEvent) {
902+
RetImplEvent =
903+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
904+
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, cuStream));
905+
UR_CHECK_ERROR(RetImplEvent->start());
906+
}
928907

929908
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
930909
cuStream, region, pSrc, CU_MEMORYTYPE_HOST, hostOrigin, hostRowPitch,
931910
hostSlicePitch, &DevPtr, CU_MEMORYTYPE_DEVICE, bufferOrigin,
932911
bufferRowPitch, bufferSlicePitch));
933912

934-
UR_CHECK_ERROR(RetImplEvent->record());
913+
if (phEvent) {
914+
UR_CHECK_ERROR(RetImplEvent->record());
915+
}
935916

936917
if (blockingWrite) {
937918
UR_CHECK_ERROR(cuStreamSynchronize(cuStream));
@@ -940,10 +921,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
940921
if (phEvent) {
941922
*phEvent = RetImplEvent.release();
942923
hBuffer->setLastEventWritingToMemObj(*phEvent);
943-
} else {
944-
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
945924
}
946-
947925
} catch (ur_result_t Err) {
948926
return Err;
949927
}
@@ -1081,6 +1059,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
10811059
ur_event_handle_t *phEvent) {
10821060
UR_ASSERT(size + offset <= std::get<BufferMem>(hBuffer->Mem).getSize(),
10831061
UR_RESULT_ERROR_INVALID_SIZE);
1062+
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
10841063
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
10851064

10861065
try {
@@ -1090,11 +1069,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
10901069
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
10911070
phEventWaitList));
10921071

1093-
// With multi dev ctx we have no choice but to record this event
1094-
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1095-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1096-
UR_COMMAND_MEM_BUFFER_FILL, hQueue, Stream));
1097-
UR_CHECK_ERROR(RetImplEvent->start());
1072+
if (phEvent) {
1073+
RetImplEvent =
1074+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1075+
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, Stream));
1076+
UR_CHECK_ERROR(RetImplEvent->start());
1077+
}
10981078

10991079
auto DstDevice = std::get<BufferMem>(hBuffer->Mem)
11001080
.getPtrWithOffset(hQueue->getDevice(), offset);
@@ -1124,13 +1104,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
11241104
}
11251105
}
11261106

1127-
UR_CHECK_ERROR(RetImplEvent->record());
11281107
if (phEvent) {
1108+
UR_CHECK_ERROR(RetImplEvent->record());
11291109
*phEvent = RetImplEvent.release();
11301110
hBuffer->setLastEventWritingToMemObj(*phEvent);
1131-
} else {
1132-
// Give buffer ownership if no event used
1133-
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
11341111
}
11351112
} catch (ur_result_t Err) {
11361113
return Err;
@@ -1925,27 +1902,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19251902
UR_ASSERT(offset + size <= std::get<BufferMem>(hBuffer->Mem).Size,
19261903
UR_RESULT_ERROR_INVALID_SIZE);
19271904

1928-
ur_result_t Result = UR_RESULT_SUCCESS;
19291905
CUdeviceptr DevPtr =
19301906
std::get<BufferMem>(hBuffer->Mem).getPtr(hQueue->getDevice());
1907+
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
19311908
ur_lock MemMigrationLock{hBuffer->MemoryMigrationMutex};
19321909

19331910
try {
19341911
ScopedContext Active(hQueue->getDevice());
19351912
CUstream CuStream = hQueue->getNextTransferStream();
19361913

1937-
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
1938-
phEventWaitList);
1914+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
1915+
phEventWaitList));
19391916

1940-
// With multi dev ctx we have no choice but to record this event
1941-
std::unique_ptr<ur_event_handle_t_> RetImplEvent =
1942-
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1943-
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1944-
UR_CHECK_ERROR(RetImplEvent->start());
1917+
if (phEvent) {
1918+
RetImplEvent =
1919+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
1920+
UR_COMMAND_MEM_BUFFER_WRITE, hQueue, CuStream));
1921+
UR_CHECK_ERROR(RetImplEvent->start());
1922+
}
19451923

19461924
UR_CHECK_ERROR(cuMemcpyHtoDAsync(DevPtr + offset, pSrc, size, CuStream));
19471925

1948-
UR_CHECK_ERROR(RetImplEvent->record());
1926+
if (phEvent) {
1927+
UR_CHECK_ERROR(RetImplEvent->record());
1928+
}
19491929

19501930
if (blockingWrite) {
19511931
UR_CHECK_ERROR(cuStreamSynchronize(CuStream));
@@ -1954,14 +1934,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
19541934
if (phEvent) {
19551935
*phEvent = RetImplEvent.release();
19561936
hBuffer->setLastEventWritingToMemObj(*phEvent);
1957-
} else {
1958-
// Give buffer ownership if no event used
1959-
hBuffer->setLastEventWritingToMemObj(RetImplEvent.release());
19601937
}
19611938
} catch (ur_result_t Err) {
1962-
Result = Err;
1939+
return Err;
19631940
}
1964-
return Result;
1941+
return UR_RESULT_SUCCESS;
19651942
}
19661943

19671944
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(

0 commit comments

Comments
 (0)