Skip to content

Commit 7142006

Browse files
author
Hugh Delaney
committed
Don't assume context contains all devices in platform
When indexing into Ptrs, which are per-context, don't use the per-platform device index.
1 parent ae74ef4 commit 7142006

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

source/adapters/cuda/memory.cpp

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -433,7 +433,7 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem,
433433

434434
if (Mem->isBuffer()) {
435435
auto &Buffer = std::get<BufferMem>(Mem->Mem);
436-
auto &DevPtr = Buffer.Ptrs[hDevice->getIndex()];
436+
auto &DevPtr = Buffer.Ptrs[hDevice->getIndex() % Buffer.Ptrs.size()];
437437

438438
// Allocation has already been made
439439
if (DevPtr != BufferMem::native_type{0}) {
@@ -456,11 +456,11 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem,
456456
try {
457457
auto &Image = std::get<SurfaceMem>(Mem->Mem);
458458
// Allocation has already been made
459-
if (Image.Arrays[hDevice->getIndex()]) {
459+
if (Image.Arrays[hDevice->getIndex() % Image.Arrays.size()]) {
460460
return UR_RESULT_SUCCESS;
461461
}
462462
UR_CHECK_ERROR(cuArray3DCreate(&ImageArray, &Image.ArrayDesc));
463-
Image.Arrays[hDevice->getIndex()] = ImageArray;
463+
Image.Arrays[hDevice->getIndex() % Image.Arrays.size()] = ImageArray;
464464

465465
// CUDA_RESOURCE_DESC is a union of different structs, shown here
466466
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TEXOBJECT.html
@@ -475,7 +475,7 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem,
475475
ImageResDesc.flags = 0;
476476

477477
UR_CHECK_ERROR(cuSurfObjectCreate(&Surface, &ImageResDesc));
478-
Image.SurfObjs[hDevice->getIndex()] = Surface;
478+
Image.SurfObjs[hDevice->getIndex() % Image.SurfObjs.size()] = Surface;
479479
} catch (ur_result_t Err) {
480480
if (ImageArray) {
481481
UR_CHECK_ERROR(cuArrayDestroy(ImageArray));
@@ -590,7 +590,9 @@ ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t Mem,
590590
UR_ASSERT(hDevice, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
591591
// Device allocation has already been initialized with most up to date
592592
// data in buffer
593-
if (Mem->HaveMigratedToDeviceSinceLastWrite[hDevice->getIndex()]) {
593+
if (Mem->HaveMigratedToDeviceSinceLastWrite
594+
[hDevice->getIndex() %
595+
Mem->HaveMigratedToDeviceSinceLastWrite.size()]) {
594596
return UR_RESULT_SUCCESS;
595597
}
596598

@@ -601,6 +603,8 @@ ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t Mem,
601603
UR_CHECK_ERROR(migrateImageToDevice(Mem, hDevice));
602604
}
603605

604-
Mem->HaveMigratedToDeviceSinceLastWrite[hDevice->getIndex()] = true;
606+
Mem->HaveMigratedToDeviceSinceLastWrite
607+
[hDevice->getIndex() % Mem->HaveMigratedToDeviceSinceLastWrite.size()] =
608+
true;
605609
return UR_RESULT_SUCCESS;
606610
}

source/adapters/cuda/memory.hpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,8 @@ struct BufferMem {
104104
throw Err;
105105
}
106106
return reinterpret_cast<native_type>(
107-
reinterpret_cast<uint8_t *>(Ptrs[Device->getIndex()]) + Offset);
107+
reinterpret_cast<uint8_t *>(Ptrs[Device->getIndex() % Ptrs.size()]) +
108+
Offset);
108109
}
109110

110111
native_type getPtr(const ur_device_handle_t Device) {
@@ -274,7 +275,7 @@ struct SurfaceMem {
274275
Err != UR_RESULT_SUCCESS) {
275276
throw Err;
276277
}
277-
return Arrays[Device->getIndex()];
278+
return Arrays[Device->getIndex() % Arrays.size()];
278279
}
279280
// Will allocate a new surface on device if not already allocated
280281
CUsurfObject getSurface(const ur_device_handle_t Device) {
@@ -283,7 +284,7 @@ struct SurfaceMem {
283284
Err != UR_RESULT_SUCCESS) {
284285
throw Err;
285286
}
286-
return SurfObjs[Device->getIndex()];
287+
return SurfObjs[Device->getIndex() % SurfObjs.size()];
287288
}
288289

289290
ur_mem_type_t getType() { return ImageDesc.type; }
@@ -514,8 +515,9 @@ struct ur_mem_handle_t_ {
514515
for (const auto &Device : Context->getDevices()) {
515516
// This event is never an interop event so will always have an associated
516517
// queue
517-
HaveMigratedToDeviceSinceLastWrite[Device->getIndex()] =
518-
Device == NewEvent->getQueue()->getDevice();
518+
HaveMigratedToDeviceSinceLastWrite
519+
[Device->getIndex() % HaveMigratedToDeviceSinceLastWrite.size()] =
520+
Device == NewEvent->getQueue()->getDevice();
519521
}
520522
}
521523
};

0 commit comments

Comments
 (0)