1212
1313#include " common.hpp"
1414#include " context.hpp"
15+ #include " enqueue.hpp"
1516#include " memory.hpp"
1617
1718// / Creates a UR Memory object using a CUDA memory allocation.
@@ -238,7 +239,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(
238239 try {
239240 if (PerformInitialCopy) {
240241 for (const auto &Device : hContext->getDevices ()) {
241- UR_CHECK_ERROR (migrateMemoryToDeviceIfNeeded (URMemObj.get (), Device));
242+ // Synchronous behaviour is best in this case
243+ ScopedContext Active (Device);
244+ CUstream Stream{0 }; // Use default stream
245+ UR_CHECK_ERROR (enqueueMigrateMemoryToDeviceIfNeeded (URMemObj.get (),
246+ Device, Stream));
247+ UR_CHECK_ERROR (cuStreamSynchronize (Stream));
242248 }
243249 }
244250
@@ -496,27 +502,28 @@ ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t Mem,
496502}
497503
498504namespace {
499- ur_result_t migrateBufferToDevice (ur_mem_handle_t Mem,
500- ur_device_handle_t hDevice) {
505+ ur_result_t enqueueMigrateBufferToDevice (ur_mem_handle_t Mem,
506+ ur_device_handle_t hDevice,
507+ CUstream Stream) {
501508 auto &Buffer = std::get<BufferMem>(Mem->Mem );
502- if (Mem->LastEventWritingToMemObj == nullptr ) {
509+ if (Mem->LastQueueWritingToMemObj == nullptr ) {
503510 // Device allocation being initialized from host for the first time
504511 if (Buffer.HostPtr ) {
505- UR_CHECK_ERROR (
506- cuMemcpyHtoD (Buffer. getPtr (hDevice), Buffer. HostPtr , Buffer.Size ));
512+ UR_CHECK_ERROR (cuMemcpyHtoDAsync (Buffer. getPtr (hDevice), Buffer. HostPtr ,
513+ Buffer.Size , Stream ));
507514 }
508- } else if (Mem->LastEventWritingToMemObj ->getQueue ()->getDevice () !=
509- hDevice) {
510- UR_CHECK_ERROR (cuMemcpyDtoD (
515+ } else if (Mem->LastQueueWritingToMemObj ->getDevice () != hDevice) {
516+ UR_CHECK_ERROR (cuMemcpyDtoDAsync (
511517 Buffer.getPtr (hDevice),
512- Buffer.getPtr (Mem->LastEventWritingToMemObj -> getQueue ()-> getDevice ()),
513- Buffer. Size ));
518+ Buffer.getPtr (Mem->LastQueueWritingToMemObj -> getDevice ()), Buffer. Size ,
519+ Stream ));
514520 }
515521 return UR_RESULT_SUCCESS;
516522}
517523
518- ur_result_t migrateImageToDevice (ur_mem_handle_t Mem,
519- ur_device_handle_t hDevice) {
524+ ur_result_t enqueueMigrateImageToDevice (ur_mem_handle_t Mem,
525+ ur_device_handle_t hDevice,
526+ CUstream Stream) {
520527 auto &Image = std::get<SurfaceMem>(Mem->Mem );
521528 // When a dimension isn't used image_desc has the size set to 1
522529 size_t PixelSizeBytes = Image.PixelTypeSizeBytes *
@@ -547,40 +554,42 @@ ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
547554 CpyDesc3D.Depth = Image.ImageDesc .depth ;
548555 }
549556
550- if (Mem->LastEventWritingToMemObj == nullptr ) {
557+ if (Mem->LastQueueWritingToMemObj == nullptr ) {
551558 if (Image.HostPtr ) {
552559 if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE1D) {
553- UR_CHECK_ERROR (
554- cuMemcpyHtoA (ImageArray, 0 , Image. HostPtr , ImageSizeBytes));
560+ UR_CHECK_ERROR (cuMemcpyHtoAAsync (ImageArray, 0 , Image. HostPtr ,
561+ ImageSizeBytes, Stream ));
555562 } else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE2D) {
556563 CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
557564 CpyDesc2D.srcHost = Image.HostPtr ;
558- UR_CHECK_ERROR (cuMemcpy2D (&CpyDesc2D));
565+ UR_CHECK_ERROR (cuMemcpy2DAsync (&CpyDesc2D, Stream ));
559566 } else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE3D) {
560567 CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_HOST;
561568 CpyDesc3D.srcHost = Image.HostPtr ;
562- UR_CHECK_ERROR (cuMemcpy3D (&CpyDesc3D));
569+ UR_CHECK_ERROR (cuMemcpy3DAsync (&CpyDesc3D, Stream ));
563570 }
564571 }
565- } else if (Mem->LastEventWritingToMemObj ->getQueue ()->getDevice () !=
566- hDevice) {
572+ } else if (Mem->LastQueueWritingToMemObj ->getDevice () != hDevice) {
567573 if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE1D) {
574+ // Blocking wait needed
575+ UR_CHECK_ERROR (urQueueFinish (Mem->LastQueueWritingToMemObj ));
568576 // FIXME: 1D memcpy from DtoD going through the host.
569577 UR_CHECK_ERROR (cuMemcpyAtoH (
570578 Image.HostPtr ,
571- Image.getArray (
572- Mem->LastEventWritingToMemObj ->getQueue ()->getDevice ()),
579+ Image.getArray (Mem->LastQueueWritingToMemObj ->getDevice ()),
573580 0 /* srcOffset*/ , ImageSizeBytes));
574581 UR_CHECK_ERROR (
575582 cuMemcpyHtoA (ImageArray, 0 , Image.HostPtr , ImageSizeBytes));
576583 } else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE2D) {
577- CpyDesc2D.srcArray = Image.getArray (
578- Mem->LastEventWritingToMemObj ->getQueue ()->getDevice ());
579- UR_CHECK_ERROR (cuMemcpy2D (&CpyDesc2D));
584+ CpyDesc2D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_DEVICE;
585+ CpyDesc2D.srcArray =
586+ Image.getArray (Mem->LastQueueWritingToMemObj ->getDevice ());
587+ UR_CHECK_ERROR (cuMemcpy2DAsync (&CpyDesc2D, Stream));
580588 } else if (Image.ImageDesc .type == UR_MEM_TYPE_IMAGE3D) {
581- CpyDesc3D.srcArray = Image.getArray (
582- Mem->LastEventWritingToMemObj ->getQueue ()->getDevice ());
583- UR_CHECK_ERROR (cuMemcpy3D (&CpyDesc3D));
589+ CpyDesc3D.srcMemoryType = CUmemorytype_enum::CU_MEMORYTYPE_DEVICE;
590+ CpyDesc3D.srcArray =
591+ Image.getArray (Mem->LastQueueWritingToMemObj ->getDevice ());
592+ UR_CHECK_ERROR (cuMemcpy3DAsync (&CpyDesc3D, Stream));
584593 }
585594 }
586595 return UR_RESULT_SUCCESS;
@@ -589,8 +598,8 @@ ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
589598
590599// If calling this entry point it is necessary to lock the memoryMigrationMutex
591600// beforehand
592- ur_result_t migrateMemoryToDeviceIfNeeded ( ur_mem_handle_t Mem,
593- const ur_device_handle_t hDevice) {
601+ ur_result_t enqueueMigrateMemoryToDeviceIfNeeded (
602+ ur_mem_handle_t Mem, const ur_device_handle_t hDevice, CUstream Stream ) {
594603 UR_ASSERT (hDevice, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
595604 // Device allocation has already been initialized with most up to date
596605 // data in buffer
@@ -601,9 +610,9 @@ ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t Mem,
601610
602611 ScopedContext Active (hDevice);
603612 if (Mem->isBuffer ()) {
604- UR_CHECK_ERROR (migrateBufferToDevice (Mem, hDevice));
613+ UR_CHECK_ERROR (enqueueMigrateBufferToDevice (Mem, hDevice, Stream ));
605614 } else {
606- UR_CHECK_ERROR (migrateImageToDevice (Mem, hDevice));
615+ UR_CHECK_ERROR (enqueueMigrateImageToDevice (Mem, hDevice, Stream ));
607616 }
608617
609618 Mem->HaveMigratedToDeviceSinceLastWrite [Mem->getContext ()->getDeviceIndex (
0 commit comments