Skip to content

Commit c0da06e

Browse files
author
Hugh Delaney
committed
Store HostPtr in ur mem
We might want to write this mem into an image on another device.
1 parent 24b3336 commit c0da06e

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
10981098
std::ignore = slicePitch;
10991099

11001100
UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
1101+
auto &Image = std::get<SurfaceMem>(hImage->Mem);
1102+
// FIXME: We are assuming that the host pointer remains valid for at least
1103+
// the lifetime of the image
1104+
if (!Image.HostPtr)
1105+
Image.HostPtr = pSrc;
11011106

11021107
ur_result_t Result = UR_RESULT_SUCCESS;
11031108

@@ -1107,8 +1112,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
11071112
Result = enqueueEventsWait(hQueue, CuStream, numEventsInWaitList,
11081113
phEventWaitList);
11091114

1110-
CUarray Array =
1111-
std::get<SurfaceMem>(hImage->Mem).getArray(hQueue->getDevice());
1115+
CUarray Array = Image.getArray(hQueue->getDevice());
11121116

11131117
CUDA_ARRAY_DESCRIPTOR ArrayDesc;
11141118
UR_CHECK_ERROR(cuArrayGetDescriptor(&ArrayDesc, Array));
@@ -1126,7 +1130,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
11261130
UR_CHECK_ERROR(RetImplEvent->start());
11271131
}
11281132

1129-
ur_mem_type_t ImgType = std::get<SurfaceMem>(hImage->Mem).getType();
1133+
ur_mem_type_t ImgType = Image.getType();
11301134
if (ImgType == UR_MEM_TYPE_IMAGE1D) {
11311135
UR_CHECK_ERROR(
11321136
cuMemcpyHtoAAsync(Array, ByteOffsetX, pSrc, BytesToCopy, CuStream));

source/adapters/cuda/memory.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -528,35 +528,35 @@ ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
528528
// dimensionality
529529
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
530530
memset(&CpyDesc2D, 0, sizeof(CpyDesc2D));
531-
CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
532531
CpyDesc2D.srcHost = Image.HostPtr;
533532
CpyDesc2D.dstMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_ARRAY;
534533
CpyDesc2D.dstArray = ImageArray;
535534
CpyDesc2D.WidthInBytes = PixelSizeBytes * Image.ImageDesc.width;
536535
CpyDesc2D.Height = Image.ImageDesc.height;
537-
UR_CHECK_ERROR(cuMemcpy2D(&CpyDesc2D));
538536
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
539537
memset(&CpyDesc3D, 0, sizeof(CpyDesc3D));
540-
CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
541538
CpyDesc3D.srcHost = Image.HostPtr;
542539
CpyDesc3D.dstMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_ARRAY;
543540
CpyDesc3D.dstArray = ImageArray;
544541
CpyDesc3D.WidthInBytes = PixelSizeBytes * Image.ImageDesc.width;
545542
CpyDesc3D.Height = Image.ImageDesc.height;
546543
CpyDesc3D.Depth = Image.ImageDesc.depth;
547-
UR_CHECK_ERROR(cuMemcpy3D(&CpyDesc3D));
548544
}
549545

550546
if (Mem->LastEventWritingToMemObj == nullptr) {
551-
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
552-
UR_CHECK_ERROR(
553-
cuMemcpyHtoA(ImageArray, 0, Image.HostPtr, ImageSizeBytes));
554-
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
555-
CpyDesc2D.srcHost = Image.HostPtr;
556-
UR_CHECK_ERROR(cuMemcpy2D(&CpyDesc2D));
557-
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
558-
CpyDesc3D.srcHost = Image.HostPtr;
559-
UR_CHECK_ERROR(cuMemcpy3D(&CpyDesc3D));
547+
if (Image.HostPtr) {
548+
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
549+
UR_CHECK_ERROR(
550+
cuMemcpyHtoA(ImageArray, 0, Image.HostPtr, ImageSizeBytes));
551+
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
552+
CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
553+
CpyDesc2D.srcHost = Image.HostPtr;
554+
UR_CHECK_ERROR(cuMemcpy2D(&CpyDesc2D));
555+
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
556+
CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
557+
CpyDesc3D.srcHost = Image.HostPtr;
558+
UR_CHECK_ERROR(cuMemcpy3D(&CpyDesc3D));
559+
}
560560
}
561561
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
562562
hDevice) {

0 commit comments

Comments
 (0)