@@ -203,9 +203,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
203
203
}
204
204
}
205
205
206
- UR_CHECK_ERROR (cuGraphAddMemsetNode (
207
- &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
208
- DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
206
+ UR_CHECK_ERROR (
207
+ cuGraphAddMemsetNode (&GraphNode, CommandBuffer->CudaGraph ,
208
+ DepsList.data (), DepsList.size (), &NodeParams,
209
+ CommandBuffer->Device ->getNativeContext ()));
209
210
210
211
// Get sync point and register the cuNode with it.
211
212
*SyncPoint =
@@ -237,7 +238,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
237
238
UR_CHECK_ERROR (cuGraphAddMemsetNode (
238
239
&GraphNodeFirst, CommandBuffer->CudaGraph , DepsList.data (),
239
240
DepsList.size (), &NodeParamsStepFirst,
240
- CommandBuffer->Device ->getContext ()));
241
+ CommandBuffer->Device ->getNativeContext ()));
241
242
242
243
// Get sync point and register the cuNode with it.
243
244
*SyncPoint = CommandBuffer->addSyncPoint (
@@ -269,7 +270,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
269
270
UR_CHECK_ERROR (cuGraphAddMemsetNode (
270
271
&GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
271
272
DepsList.size (), &NodeParamsStep,
272
- CommandBuffer->Device ->getContext ()));
273
+ CommandBuffer->Device ->getNativeContext ()));
273
274
274
275
GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
275
276
// Get sync point and register the cuNode with it.
@@ -478,7 +479,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
478
479
479
480
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
480
481
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
481
- &NodeParams, hCommandBuffer->Device ->getContext ()));
482
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
482
483
483
484
// Get sync point and register the cuNode with it.
484
485
*pSyncPoint =
@@ -513,16 +514,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
513
514
}
514
515
515
516
try {
516
- auto Src = std::get<BufferMem>(hSrcMem->Mem ).get () + srcOffset;
517
- auto Dst = std::get<BufferMem>(hDstMem->Mem ).get () + dstOffset;
517
+ auto Src = std::get<BufferMem>(hSrcMem->Mem )
518
+ .getPtrWithOffset (hCommandBuffer->Device , srcOffset);
519
+ auto Dst = std::get<BufferMem>(hDstMem->Mem )
520
+ .getPtrWithOffset (hCommandBuffer->Device , dstOffset);
518
521
519
522
CUDA_MEMCPY3D NodeParams = {};
520
523
setCopyParams (&Src, CU_MEMORYTYPE_DEVICE, &Dst, CU_MEMORYTYPE_DEVICE, size,
521
524
NodeParams);
522
525
523
526
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
524
527
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
525
- &NodeParams, hCommandBuffer->Device ->getContext ()));
528
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
526
529
527
530
// Get sync point and register the cuNode with it.
528
531
*pSyncPoint =
@@ -553,8 +556,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
553
556
}
554
557
555
558
try {
556
- CUdeviceptr SrcPtr = std::get<BufferMem>(hSrcMem->Mem ).get ();
557
- CUdeviceptr DstPtr = std::get<BufferMem>(hDstMem->Mem ).get ();
559
+ auto SrcPtr =
560
+ std::get<BufferMem>(hSrcMem->Mem ).getPtr (hCommandBuffer->Device );
561
+ auto DstPtr =
562
+ std::get<BufferMem>(hDstMem->Mem ).getPtr (hCommandBuffer->Device );
558
563
CUDA_MEMCPY3D NodeParams = {};
559
564
560
565
setCopyRectParams (region, &SrcPtr, CU_MEMORYTYPE_DEVICE, srcOrigin,
@@ -563,7 +568,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
563
568
564
569
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
565
570
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
566
- &NodeParams, hCommandBuffer->Device ->getContext ()));
571
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
567
572
568
573
// Get sync point and register the cuNode with it.
569
574
*pSyncPoint =
@@ -593,15 +598,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
593
598
}
594
599
595
600
try {
596
- auto Dst = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
601
+ auto Dst = std::get<BufferMem>(hBuffer->Mem )
602
+ .getPtrWithOffset (hCommandBuffer->Device , offset);
597
603
598
604
CUDA_MEMCPY3D NodeParams = {};
599
605
setCopyParams (pSrc, CU_MEMORYTYPE_HOST, &Dst, CU_MEMORYTYPE_DEVICE, size,
600
606
NodeParams);
601
607
602
608
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
603
609
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
604
- &NodeParams, hCommandBuffer->Device ->getContext ()));
610
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
605
611
606
612
// Get sync point and register the cuNode with it.
607
613
*pSyncPoint =
@@ -630,15 +636,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
630
636
}
631
637
632
638
try {
633
- auto Src = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
639
+ auto Src = std::get<BufferMem>(hBuffer->Mem )
640
+ .getPtrWithOffset (hCommandBuffer->Device , offset);
634
641
635
642
CUDA_MEMCPY3D NodeParams = {};
636
643
setCopyParams (&Src, CU_MEMORYTYPE_DEVICE, pDst, CU_MEMORYTYPE_HOST, size,
637
644
NodeParams);
638
645
639
646
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
640
647
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
641
- &NodeParams, hCommandBuffer->Device ->getContext ()));
648
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
642
649
643
650
// Get sync point and register the cuNode with it.
644
651
*pSyncPoint =
@@ -670,7 +677,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
670
677
}
671
678
672
679
try {
673
- CUdeviceptr DstPtr = std::get<BufferMem>(hBuffer->Mem ).get ();
680
+ auto DstPtr =
681
+ std::get<BufferMem>(hBuffer->Mem ).getPtr (hCommandBuffer->Device );
674
682
CUDA_MEMCPY3D NodeParams = {};
675
683
676
684
setCopyRectParams (region, pSrc, CU_MEMORYTYPE_HOST, hostOffset,
@@ -680,7 +688,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
680
688
681
689
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
682
690
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
683
- &NodeParams, hCommandBuffer->Device ->getContext ()));
691
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
684
692
685
693
// Get sync point and register the cuNode with it.
686
694
*pSyncPoint =
@@ -712,7 +720,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
712
720
}
713
721
714
722
try {
715
- CUdeviceptr SrcPtr = std::get<BufferMem>(hBuffer->Mem ).get ();
723
+ auto SrcPtr =
724
+ std::get<BufferMem>(hBuffer->Mem ).getPtr (hCommandBuffer->Device );
716
725
CUDA_MEMCPY3D NodeParams = {};
717
726
718
727
setCopyRectParams (region, &SrcPtr, CU_MEMORYTYPE_DEVICE, bufferOffset,
@@ -722,7 +731,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
722
731
723
732
UR_CHECK_ERROR (cuGraphAddMemcpyNode (
724
733
&GraphNode, hCommandBuffer->CudaGraph , DepsList.data (), DepsList.size (),
725
- &NodeParams, hCommandBuffer->Device ->getContext ()));
734
+ &NodeParams, hCommandBuffer->Device ->getNativeContext ()));
726
735
727
736
// Get sync point and register the cuNode with it.
728
737
*pSyncPoint =
@@ -821,7 +830,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
821
830
PatternSizeIsValid,
822
831
UR_RESULT_ERROR_INVALID_SIZE);
823
832
824
- auto DstDevice = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
833
+ auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
834
+ .getPtrWithOffset (hCommandBuffer->Device , offset);
825
835
826
836
return enqueueCommandBufferFillHelper (
827
837
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
@@ -854,7 +864,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
854
864
855
865
try {
856
866
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
857
- ScopedContext Active (hQueue->getContext ());
867
+ ScopedContext Active (hQueue->getDevice ());
858
868
uint32_t StreamToken;
859
869
ur_stream_guard_ Guard;
860
870
CUstream CuStream = hQueue->getNextComputeStream (
@@ -972,7 +982,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
972
982
if (ArgValue == nullptr ) {
973
983
Kernel->setKernelArg (ArgIndex, 0 , nullptr );
974
984
} else {
975
- CUdeviceptr CuPtr = std::get<BufferMem>(ArgValue->Mem ).get ();
985
+ CUdeviceptr CuPtr =
986
+ std::get<BufferMem>(ArgValue->Mem ).getPtr (CommandBuffer->Device );
976
987
Kernel->setKernelArg (ArgIndex, sizeof (CUdeviceptr), (void *)&CuPtr);
977
988
}
978
989
} catch (ur_result_t Err) {
0 commit comments