@@ -203,9 +203,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
203203 }
204204 }
205205
206- UR_CHECK_ERROR (cuGraphAddMemsetNode (
207- &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
208- DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
206+ UR_CHECK_ERROR (
207+ cuGraphAddMemsetNode (&GraphNode, CommandBuffer->CudaGraph ,
208+ DepsList.data (), DepsList.size (), &NodeParams,
209+ CommandBuffer->Device ->getNativeContext ()));
209210
210211 // Get sync point and register the cuNode with it.
211212 *SyncPoint =
@@ -237,7 +238,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
237238 UR_CHECK_ERROR (cuGraphAddMemsetNode (
238239 &GraphNodeFirst, CommandBuffer->CudaGraph , DepsList.data (),
239240 DepsList.size (), &NodeParamsStepFirst,
240- CommandBuffer->Device ->getContext ()));
241+ CommandBuffer->Device ->getNativeContext ()));
241242
242243 // Get sync point and register the cuNode with it.
243244 *SyncPoint = CommandBuffer->addSyncPoint (
@@ -269,7 +270,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
269270 UR_CHECK_ERROR (cuGraphAddMemsetNode (
270271 &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
271272 DepsList.size (), &NodeParamsStep,
272- CommandBuffer->Device ->getContext ()));
273+ CommandBuffer->Device ->getNativeContext ()));
273274
274275 GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
275276 // Get sync point and register the cuNode with it.
@@ -478,7 +479,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
478479
479480 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
480481 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
481- &NodeParams, hCommandBuffer->Device ->getContext ()));
482+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
482483
483484 // Get sync point and register the cuNode with it.
484485 *pSyncPoint =
@@ -513,16 +514,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
513514 }
514515
515516 try {
516- auto Src = std::get<BufferMem>(hSrcMem->Mem ).get () + srcOffset;
517- auto Dst = std::get<BufferMem>(hDstMem->Mem ).get () + dstOffset;
517+ auto Src = std::get<BufferMem>(hSrcMem->Mem )
518+ .getPtrWithOffset (hCommandBuffer->Device , srcOffset);
519+ auto Dst = std::get<BufferMem>(hDstMem->Mem )
520+ .getPtrWithOffset (hCommandBuffer->Device , dstOffset);
518521
519522 CUDA_MEMCPY3D NodeParams = {};
520523 setCopyParams (&Src, CU_MEMORYTYPE_DEVICE, &Dst, CU_MEMORYTYPE_DEVICE, size,
521524 NodeParams);
522525
523526 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
524527 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
525- &NodeParams, hCommandBuffer->Device ->getContext ()));
528+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
526529
527530 // Get sync point and register the cuNode with it.
528531 *pSyncPoint =
@@ -553,8 +556,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
553556 }
554557
555558 try {
556- CUdeviceptr SrcPtr = std::get<BufferMem>(hSrcMem->Mem ).get ();
557- CUdeviceptr DstPtr = std::get<BufferMem>(hDstMem->Mem ).get ();
559+ auto SrcPtr =
560+ std::get<BufferMem>(hSrcMem->Mem ).getPtr (hCommandBuffer->Device );
561+ auto DstPtr =
562+ std::get<BufferMem>(hDstMem->Mem ).getPtr (hCommandBuffer->Device );
558563 CUDA_MEMCPY3D NodeParams = {};
559564
560565 setCopyRectParams (region, &SrcPtr, CU_MEMORYTYPE_DEVICE, srcOrigin,
@@ -563,7 +568,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
563568
564569 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
565570 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
566- &NodeParams, hCommandBuffer->Device ->getContext ()));
571+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
567572
568573 // Get sync point and register the cuNode with it.
569574 *pSyncPoint =
@@ -593,15 +598,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
593598 }
594599
595600 try {
596- auto Dst = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
601+ auto Dst = std::get<BufferMem>(hBuffer->Mem )
602+ .getPtrWithOffset (hCommandBuffer->Device , offset);
597603
598604 CUDA_MEMCPY3D NodeParams = {};
599605 setCopyParams (pSrc, CU_MEMORYTYPE_HOST, &Dst, CU_MEMORYTYPE_DEVICE, size,
600606 NodeParams);
601607
602608 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
603609 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
604- &NodeParams, hCommandBuffer->Device ->getContext ()));
610+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
605611
606612 // Get sync point and register the cuNode with it.
607613 *pSyncPoint =
@@ -630,15 +636,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
630636 }
631637
632638 try {
633- auto Src = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
639+ auto Src = std::get<BufferMem>(hBuffer->Mem )
640+ .getPtrWithOffset (hCommandBuffer->Device , offset);
634641
635642 CUDA_MEMCPY3D NodeParams = {};
636643 setCopyParams (&Src, CU_MEMORYTYPE_DEVICE, pDst, CU_MEMORYTYPE_HOST, size,
637644 NodeParams);
638645
639646 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
640647 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
641- &NodeParams, hCommandBuffer->Device ->getContext ()));
648+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
642649
643650 // Get sync point and register the cuNode with it.
644651 *pSyncPoint =
@@ -670,7 +677,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
670677 }
671678
672679 try {
673- CUdeviceptr DstPtr = std::get<BufferMem>(hBuffer->Mem ).get ();
680+ auto DstPtr =
681+ std::get<BufferMem>(hBuffer->Mem ).getPtr (hCommandBuffer->Device );
674682 CUDA_MEMCPY3D NodeParams = {};
675683
676684 setCopyRectParams (region, pSrc, CU_MEMORYTYPE_HOST, hostOffset,
@@ -680,7 +688,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
680688
681689 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
682690 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
683- &NodeParams, hCommandBuffer->Device ->getContext ()));
691+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
684692
685693 // Get sync point and register the cuNode with it.
686694 *pSyncPoint =
@@ -712,7 +720,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
712720 }
713721
714722 try {
715- CUdeviceptr SrcPtr = std::get<BufferMem>(hBuffer->Mem ).get ();
723+ auto SrcPtr =
724+ std::get<BufferMem>(hBuffer->Mem ).getPtr (hCommandBuffer->Device );
716725 CUDA_MEMCPY3D NodeParams = {};
717726
718727 setCopyRectParams (region, &SrcPtr, CU_MEMORYTYPE_DEVICE, bufferOffset,
@@ -722,7 +731,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
722731
723732 UR_CHECK_ERROR (cuGraphAddMemcpyNode (
724733 &GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
725- &NodeParams, hCommandBuffer->Device ->getContext ()));
734+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
726735
727736 // Get sync point and register the cuNode with it.
728737 *pSyncPoint =
@@ -821,7 +830,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
821830 PatternSizeIsValid,
822831 UR_RESULT_ERROR_INVALID_SIZE);
823832
824- auto DstDevice = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
833+ auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
834+ .getPtrWithOffset (hCommandBuffer->Device , offset);
825835
826836 return enqueueCommandBufferFillHelper (
827837 hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
@@ -854,7 +864,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
854864
855865 try {
856866 std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
857- ScopedContext Active (hQueue->getContext ());
867+ ScopedContext Active (hQueue->getDevice ());
858868 uint32_t StreamToken;
859869 ur_stream_guard_ Guard;
860870 CUstream CuStream = hQueue->getNextComputeStream (
@@ -972,7 +982,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
972982 if (ArgValue == nullptr ) {
973983 Kernel->setKernelArg (ArgIndex, 0 , nullptr );
974984 } else {
975- CUdeviceptr CuPtr = std::get<BufferMem>(ArgValue->Mem ).get ();
985+ CUdeviceptr CuPtr =
986+ std::get<BufferMem>(ArgValue->Mem ).getPtr (CommandBuffer->Device );
976987 Kernel->setKernelArg (ArgIndex, sizeof (CUdeviceptr), (void *)&CuPtr);
977988 }
978989 } catch (ur_result_t Err) {
0 commit comments