Skip to content

Commit d3502dc

Browse files
authored
Merge pull request #1565 from hdelan/cuda-multi-dev-ctx
[CUDA] CUDA adapter multi device context
2 parents c911a9b + 7142006 commit d3502dc

27 files changed

+1157
-706
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
203203
}
204204
}
205205

206-
UR_CHECK_ERROR(cuGraphAddMemsetNode(
207-
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
208-
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
206+
UR_CHECK_ERROR(
207+
cuGraphAddMemsetNode(&GraphNode, CommandBuffer->CudaGraph,
208+
DepsList.data(), DepsList.size(), &NodeParams,
209+
CommandBuffer->Device->getNativeContext()));
209210

210211
// Get sync point and register the cuNode with it.
211212
*SyncPoint =
@@ -237,7 +238,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
237238
UR_CHECK_ERROR(cuGraphAddMemsetNode(
238239
&GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
239240
DepsList.size(), &NodeParamsStepFirst,
240-
CommandBuffer->Device->getContext()));
241+
CommandBuffer->Device->getNativeContext()));
241242

242243
// Get sync point and register the cuNode with it.
243244
*SyncPoint = CommandBuffer->addSyncPoint(
@@ -269,7 +270,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
269270
UR_CHECK_ERROR(cuGraphAddMemsetNode(
270271
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
271272
DepsList.size(), &NodeParamsStep,
272-
CommandBuffer->Device->getContext()));
273+
CommandBuffer->Device->getNativeContext()));
273274

274275
GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
275276
// Get sync point and register the cuNode with it.
@@ -478,7 +479,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
478479

479480
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
480481
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
481-
&NodeParams, hCommandBuffer->Device->getContext()));
482+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
482483

483484
// Get sync point and register the cuNode with it.
484485
*pSyncPoint =
@@ -513,16 +514,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
513514
}
514515

515516
try {
516-
auto Src = std::get<BufferMem>(hSrcMem->Mem).get() + srcOffset;
517-
auto Dst = std::get<BufferMem>(hDstMem->Mem).get() + dstOffset;
517+
auto Src = std::get<BufferMem>(hSrcMem->Mem)
518+
.getPtrWithOffset(hCommandBuffer->Device, srcOffset);
519+
auto Dst = std::get<BufferMem>(hDstMem->Mem)
520+
.getPtrWithOffset(hCommandBuffer->Device, dstOffset);
518521

519522
CUDA_MEMCPY3D NodeParams = {};
520523
setCopyParams(&Src, CU_MEMORYTYPE_DEVICE, &Dst, CU_MEMORYTYPE_DEVICE, size,
521524
NodeParams);
522525

523526
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
524527
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
525-
&NodeParams, hCommandBuffer->Device->getContext()));
528+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
526529

527530
// Get sync point and register the cuNode with it.
528531
*pSyncPoint =
@@ -553,8 +556,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
553556
}
554557

555558
try {
556-
CUdeviceptr SrcPtr = std::get<BufferMem>(hSrcMem->Mem).get();
557-
CUdeviceptr DstPtr = std::get<BufferMem>(hDstMem->Mem).get();
559+
auto SrcPtr =
560+
std::get<BufferMem>(hSrcMem->Mem).getPtr(hCommandBuffer->Device);
561+
auto DstPtr =
562+
std::get<BufferMem>(hDstMem->Mem).getPtr(hCommandBuffer->Device);
558563
CUDA_MEMCPY3D NodeParams = {};
559564

560565
setCopyRectParams(region, &SrcPtr, CU_MEMORYTYPE_DEVICE, srcOrigin,
@@ -563,7 +568,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
563568

564569
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
565570
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
566-
&NodeParams, hCommandBuffer->Device->getContext()));
571+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
567572

568573
// Get sync point and register the cuNode with it.
569574
*pSyncPoint =
@@ -593,15 +598,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
593598
}
594599

595600
try {
596-
auto Dst = std::get<BufferMem>(hBuffer->Mem).get() + offset;
601+
auto Dst = std::get<BufferMem>(hBuffer->Mem)
602+
.getPtrWithOffset(hCommandBuffer->Device, offset);
597603

598604
CUDA_MEMCPY3D NodeParams = {};
599605
setCopyParams(pSrc, CU_MEMORYTYPE_HOST, &Dst, CU_MEMORYTYPE_DEVICE, size,
600606
NodeParams);
601607

602608
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
603609
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
604-
&NodeParams, hCommandBuffer->Device->getContext()));
610+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
605611

606612
// Get sync point and register the cuNode with it.
607613
*pSyncPoint =
@@ -630,15 +636,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
630636
}
631637

632638
try {
633-
auto Src = std::get<BufferMem>(hBuffer->Mem).get() + offset;
639+
auto Src = std::get<BufferMem>(hBuffer->Mem)
640+
.getPtrWithOffset(hCommandBuffer->Device, offset);
634641

635642
CUDA_MEMCPY3D NodeParams = {};
636643
setCopyParams(&Src, CU_MEMORYTYPE_DEVICE, pDst, CU_MEMORYTYPE_HOST, size,
637644
NodeParams);
638645

639646
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
640647
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
641-
&NodeParams, hCommandBuffer->Device->getContext()));
648+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
642649

643650
// Get sync point and register the cuNode with it.
644651
*pSyncPoint =
@@ -670,7 +677,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
670677
}
671678

672679
try {
673-
CUdeviceptr DstPtr = std::get<BufferMem>(hBuffer->Mem).get();
680+
auto DstPtr =
681+
std::get<BufferMem>(hBuffer->Mem).getPtr(hCommandBuffer->Device);
674682
CUDA_MEMCPY3D NodeParams = {};
675683

676684
setCopyRectParams(region, pSrc, CU_MEMORYTYPE_HOST, hostOffset,
@@ -680,7 +688,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
680688

681689
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
682690
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
683-
&NodeParams, hCommandBuffer->Device->getContext()));
691+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
684692

685693
// Get sync point and register the cuNode with it.
686694
*pSyncPoint =
@@ -712,7 +720,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
712720
}
713721

714722
try {
715-
CUdeviceptr SrcPtr = std::get<BufferMem>(hBuffer->Mem).get();
723+
auto SrcPtr =
724+
std::get<BufferMem>(hBuffer->Mem).getPtr(hCommandBuffer->Device);
716725
CUDA_MEMCPY3D NodeParams = {};
717726

718727
setCopyRectParams(region, &SrcPtr, CU_MEMORYTYPE_DEVICE, bufferOffset,
@@ -722,7 +731,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
722731

723732
UR_CHECK_ERROR(cuGraphAddMemcpyNode(
724733
&GraphNode, hCommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
725-
&NodeParams, hCommandBuffer->Device->getContext()));
734+
&NodeParams, hCommandBuffer->Device->getNativeContext()));
726735

727736
// Get sync point and register the cuNode with it.
728737
*pSyncPoint =
@@ -821,7 +830,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
821830
PatternSizeIsValid,
822831
UR_RESULT_ERROR_INVALID_SIZE);
823832

824-
auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;
833+
auto DstDevice = std::get<BufferMem>(hBuffer->Mem)
834+
.getPtrWithOffset(hCommandBuffer->Device, offset);
825835

826836
return enqueueCommandBufferFillHelper(
827837
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
@@ -854,7 +864,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
854864

855865
try {
856866
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
857-
ScopedContext Active(hQueue->getContext());
867+
ScopedContext Active(hQueue->getDevice());
858868
uint32_t StreamToken;
859869
ur_stream_guard_ Guard;
860870
CUstream CuStream = hQueue->getNextComputeStream(
@@ -972,7 +982,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
972982
if (ArgValue == nullptr) {
973983
Kernel->setKernelArg(ArgIndex, 0, nullptr);
974984
} else {
975-
CUdeviceptr CuPtr = std::get<BufferMem>(ArgValue->Mem).get();
985+
CUdeviceptr CuPtr =
986+
std::get<BufferMem>(ArgValue->Mem).getPtr(CommandBuffer->Device);
976987
Kernel->setKernelArg(ArgIndex, sizeof(CUdeviceptr), (void *)&CuPtr);
977988
}
978989
} catch (ur_result_t Err) {

source/adapters/cuda/context.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,19 @@ UR_APIEXPORT ur_result_t UR_APICALL
4646
urContextCreate(uint32_t DeviceCount, const ur_device_handle_t *phDevices,
4747
const ur_context_properties_t *pProperties,
4848
ur_context_handle_t *phContext) {
49-
std::ignore = DeviceCount;
5049
std::ignore = pProperties;
5150

52-
assert(DeviceCount == 1);
53-
ur_result_t RetErr = UR_RESULT_SUCCESS;
54-
5551
std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
5652
try {
5753
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
58-
new ur_context_handle_t_{*phDevices});
54+
new ur_context_handle_t_{phDevices, DeviceCount});
5955
*phContext = ContextPtr.release();
6056
} catch (ur_result_t Err) {
61-
RetErr = Err;
57+
return Err;
6258
} catch (...) {
63-
RetErr = UR_RESULT_ERROR_OUT_OF_RESOURCES;
59+
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
6460
}
65-
return RetErr;
61+
return UR_RESULT_SUCCESS;
6662
}
6763

6864
UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
@@ -72,9 +68,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
7268

7369
switch (static_cast<uint32_t>(ContextInfoType)) {
7470
case UR_CONTEXT_INFO_NUM_DEVICES:
75-
return ReturnValue(1);
71+
return ReturnValue(static_cast<uint32_t>(hContext->getDevices().size()));
7672
case UR_CONTEXT_INFO_DEVICES:
77-
return ReturnValue(hContext->getDevice());
73+
return ReturnValue(hContext->getDevices().data(),
74+
hContext->getDevices().size());
7875
case UR_CONTEXT_INFO_REFERENCE_COUNT:
7976
return ReturnValue(hContext->getReferenceCount());
8077
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES: {
@@ -88,7 +85,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
8885
int Major = 0;
8986
UR_CHECK_ERROR(cuDeviceGetAttribute(
9087
&Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
91-
hContext->getDevice()->get()));
88+
hContext->getDevices()[0]->get()));
9289
uint32_t Capabilities =
9390
(Major >= 7) ? UR_MEMORY_SCOPE_CAPABILITY_FLAG_WORK_ITEM |
9491
UR_MEMORY_SCOPE_CAPABILITY_FLAG_SUB_GROUP |
@@ -137,7 +134,10 @@ urContextRetain(ur_context_handle_t hContext) {
137134

138135
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
139136
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
140-
*phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
137+
// FIXME: this entry point has been deprecated in the SYCL RT and should be
138+
// changed to unsupported once the deprecation period has elapsed.
139+
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
140+
hContext->getDevices()[0]->getNativeContext());
141141
return UR_RESULT_SUCCESS;
142142
}
143143

source/adapters/cuda/context.hpp

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,26 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
3333
///
3434
/// One of the main differences between the UR API and the CUDA driver API is
3535
/// that the second modifies the state of the threads by assigning
36-
/// `CUcontext` objects to threads. `CUcontext` objects store data associated
36+
/// \c CUcontext objects to threads. \c CUcontext objects store data associated
3737
/// with a given device and control access to said device from the user side.
3838
/// UR API context are objects that are passed to functions, and not bound
3939
/// to threads.
40-
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
41-
/// holds the CUDA context data. The RAII object \ref ScopedContext implements
42-
/// the active context behavior.
4340
///
44-
/// <b> Primary vs User-defined context </b>
41+
/// Since the \c ur_context_handle_t can contain multiple devices, and a \c
42+
/// CUcontext refers to only a single device, the \c CUcontext is more tightly
43+
/// coupled to a \c ur_device_handle_t than a \c ur_context_handle_t. In order
44+
/// to remove some ambiguities about the different semantics of \c
45+
/// ur_context_handle_t and native \c CUcontext, we access the native \c
46+
/// CUcontext solely through the \c ur_device_handle_t class, by using the
47+
/// object \ref ScopedContext, which sets the active device (by setting the
48+
/// active native \c CUcontext).
4549
///
46-
/// CUDA has two different types of context, the Primary context,
47-
/// which is usable by all threads on a given process for a given device, and
48-
/// the aforementioned custom contexts.
49-
/// The CUDA documentation, confirmed with performance analysis, suggest using
50-
/// the Primary context whenever possible.
51-
/// The Primary context is also used by the CUDA Runtime API.
52-
/// For UR applications to interop with CUDA Runtime API, they have to use
53-
/// the primary context - and make that active in the thread.
54-
/// The `ur_context_handle_t_` object can be constructed with a `kind` parameter
55-
/// that allows to construct a Primary or `user-defined` context, so that
56-
/// the UR object interface is always the same.
50+
/// <b> Primary vs User-defined \c CUcontext </b>
51+
///
52+
/// CUDA has two different types of \c CUcontext, the Primary context, which is
53+
/// usable by all threads on a given process for a given device, and the
54+
/// aforementioned custom \c CUcontext s. The CUDA documentation, confirmed with
55+
/// performance analysis, suggests using the Primary context whenever possible.
5756
///
5857
/// <b> Destructor callback </b>
5958
///
@@ -63,6 +62,18 @@ typedef void (*ur_context_extended_deleter_t)(void *user_data);
6362
/// See proposal for details.
6463
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
6564
///
65+
///
66+
/// <b> Memory Management for Devices in a Context </b>
67+
///
68+
/// A \c ur_mem_handle_t is associated with a \c ur_context_handle_t_, which
69+
/// may refer to multiple devices. Therefore the \c ur_mem_handle_t must
70+
/// handle a native allocation for each device in the context. UR is
71+
/// responsible for automatically handling event dependencies for kernels
72+
/// writing to or reading from the same \c ur_mem_handle_t and migrating memory
73+
/// between native allocations for devices in the same \c ur_context_handle_t_
74+
/// if necessary.
75+
///
76+
///
6677
struct ur_context_handle_t_ {
6778

6879
struct deleter_data {
@@ -72,18 +83,21 @@ struct ur_context_handle_t_ {
7283
void operator()() { Function(UserData); }
7384
};
7485

75-
using native_type = CUcontext;
76-
77-
native_type CUContext;
78-
ur_device_handle_t DeviceID;
86+
std::vector<ur_device_handle_t> Devices;
7987
std::atomic_uint32_t RefCount;
8088

81-
ur_context_handle_t_(ur_device_handle_t_ *DevID)
82-
: CUContext{DevID->getContext()}, DeviceID{DevID}, RefCount{1} {
83-
urDeviceRetain(DeviceID);
89+
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
90+
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
91+
for (auto &Dev : Devices) {
92+
urDeviceRetain(Dev);
93+
}
8494
};
8595

86-
~ur_context_handle_t_() { urDeviceRelease(DeviceID); }
96+
~ur_context_handle_t_() {
97+
for (auto &Dev : Devices) {
98+
urDeviceRelease(Dev);
99+
}
100+
}
87101

88102
void invokeExtendedDeleters() {
89103
std::lock_guard<std::mutex> Guard(Mutex);
@@ -98,9 +112,9 @@ struct ur_context_handle_t_ {
98112
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
99113
}
100114

101-
ur_device_handle_t getDevice() const noexcept { return DeviceID; }
102-
103-
native_type get() const noexcept { return CUContext; }
115+
const std::vector<ur_device_handle_t> &getDevices() const noexcept {
116+
return Devices;
117+
}
104118

105119
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
106120

@@ -123,12 +137,11 @@ struct ur_context_handle_t_ {
123137
namespace {
124138
class ScopedContext {
125139
public:
126-
ScopedContext(ur_context_handle_t Context) {
127-
if (!Context) {
128-
throw UR_RESULT_ERROR_INVALID_CONTEXT;
140+
ScopedContext(ur_device_handle_t Device) {
141+
if (!Device) {
142+
throw UR_RESULT_ERROR_INVALID_DEVICE;
129143
}
130-
131-
setContext(Context->get());
144+
setContext(Device->getNativeContext());
132145
}
133146

134147
ScopedContext(CUcontext NativeContext) { setContext(NativeContext); }

source/adapters/cuda/device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
4747

4848
static constexpr uint32_t MaxWorkItemDimensions = 3u;
4949

50-
ScopedContext Active(hDevice->getContext());
50+
ScopedContext Active(hDevice);
5151

5252
switch ((uint32_t)propName) {
5353
case UR_DEVICE_INFO_TYPE: {
@@ -1234,7 +1234,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
12341234
uint64_t *pDeviceTimestamp,
12351235
uint64_t *pHostTimestamp) {
12361236
CUevent Event;
1237-
ScopedContext Active(hDevice->getContext());
1237+
ScopedContext Active(hDevice);
12381238

12391239
if (pDeviceTimestamp) {
12401240
UR_CHECK_ERROR(cuEventCreate(&Event, CU_EVENT_DEFAULT));

0 commit comments

Comments
 (0)