Skip to content

Commit 49c6e6d

Browse files
committed
Make consistent with cuda multidev patch.
Signed-off-by: JackAKirk <[email protected]>
1 parent f22e096 commit 49c6e6d

File tree

1 file changed

+69
-13
lines changed

1 file changed

+69
-13
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 69 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -499,7 +499,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
499499
}
500500

501501
// Preconditions
502-
UR_ASSERT(hQueue->getContext() == hKernel->getContext(),
502+
UR_ASSERT(hQueue->getDevice() == hKernel->getProgram()->getDevice(),
503503
UR_RESULT_ERROR_INVALID_KERNEL);
504504
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
505505
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
@@ -538,6 +538,38 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
538538
}
539539
}
540540

541+
std::vector<ur_event_handle_t> DepEvents(
542+
phEventWaitList, phEventWaitList + numEventsInWaitList);
543+
std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
544+
545+
// phEventWaitList only contains events that are handed to UR by the SYCL
546+
// runtime. However since UR handles memory dependencies within a context
547+
// we may need to add more events to our dependent events list if the UR
548+
// context contains multiple devices
549+
if (hQueue->getContext()->Devices.size() > 1) {
550+
MemMigrationLocks.reserve(hKernel->Args.MemObjArgs.size());
551+
for (auto &MemArg : hKernel->Args.MemObjArgs) {
552+
bool PushBack = false;
553+
if (auto MemDepEvent = MemArg.Mem->LastEventWritingToMemObj;
554+
MemDepEvent && std::find(DepEvents.begin(), DepEvents.end(),
555+
MemDepEvent) == DepEvents.end()) {
556+
DepEvents.push_back(MemDepEvent);
557+
PushBack = true;
558+
}
559+
if ((MemArg.AccessFlags &
560+
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)) ||
561+
PushBack) {
562+
if (std::find_if(MemMigrationLocks.begin(), MemMigrationLocks.end(),
563+
[MemArg](auto &Lock) {
564+
return Lock.first == MemArg.Mem;
565+
}) == MemMigrationLocks.end())
566+
MemMigrationLocks.emplace_back(
567+
std::pair{MemArg.Mem, ur_lock{MemArg.Mem->MemoryMigrationMutex}});
568+
}
569+
}
570+
}
571+
572+
// Early exit for zero size kernel
541573
if (*pGlobalWorkSize == 0) {
542574
return urEnqueueEventsWaitWithBarrier(hQueue, numEventsInWaitList,
543575
phEventWaitList, phEvent);
@@ -549,26 +581,37 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
549581
size_t BlocksPerGrid[3] = {1u, 1u, 1u};
550582

551583
uint32_t LocalSize = hKernel->getLocalSize();
552-
ur_result_t Result = UR_RESULT_SUCCESS;
553584
CUfunction CuFunc = hKernel->get();
554585

555-
Result = setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
556-
nullptr, pGlobalWorkSize, pLocalWorkSize, hKernel,
557-
CuFunc, ThreadsPerBlock, BlocksPerGrid);
558-
if (Result != UR_RESULT_SUCCESS) {
559-
return Result;
560-
}
586+
// This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
587+
// using the standard UR_CHECK_ERROR
588+
if (ur_result_t Ret =
589+
setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
590+
nullptr, pGlobalWorkSize, pLocalWorkSize,
591+
hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
592+
Ret != UR_RESULT_SUCCESS)
593+
return Ret;
561594

562595
try {
563596
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
564597

598+
ScopedContext Active(hQueue->getDevice());
565599
uint32_t StreamToken;
566600
ur_stream_guard_ Guard;
567601
CUstream CuStream = hQueue->getNextComputeStream(
568602
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
569603

570-
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
571-
phEventWaitList);
604+
if (DepEvents.size()) {
605+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, CuStream, DepEvents.size(),
606+
DepEvents.data()));
607+
}
608+
609+
// For memory migration across devices in the same context
610+
if (hQueue->getContext()->Devices.size() > 1) {
611+
for (auto &MemArg : hKernel->Args.MemObjArgs) {
612+
migrateMemoryToDeviceIfNeeded(MemArg.Mem, hQueue->getDevice());
613+
}
614+
}
572615

573616
if (phEvent) {
574617
RetImplEvent =
@@ -577,6 +620,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
577620
UR_CHECK_ERROR(RetImplEvent->start());
578621
}
579622

623+
// Once event has been started we can unlock MemoryMigrationMutex
624+
if (hQueue->getContext()->Devices.size() > 1) {
625+
for (auto &MemArg : hKernel->Args.MemObjArgs) {
626+
// Telling the ur_mem_handle_t that it will need to wait on this kernel
627+
// if it has been written to
628+
if (phEvent && (MemArg.AccessFlags &
629+
(UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY))) {
630+
MemArg.Mem->setLastEventWritingToMemObj(RetImplEvent.get());
631+
}
632+
}
633+
// We can release the MemoryMigrationMutexes now
634+
MemMigrationLocks.clear();
635+
}
636+
580637
auto &ArgIndices = hKernel->getArgIndices();
581638

582639
CUlaunchConfig launch_config;
@@ -605,12 +662,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
605662
}
606663

607664
} catch (ur_result_t Err) {
608-
Result = Err;
665+
return Err;
609666
}
610-
return Result;
667+
return UR_RESULT_SUCCESS;
611668
}
612669

613-
614670
/// Set parameters for general 3D memory copy.
615671
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
616672
/// must be a pointer to a CUdeviceptr

0 commit comments

Comments (0)