@@ -99,6 +99,91 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
99
99
Params.Depth = 1 ;
100
100
}
101
101
102
+ // Helper function for enqueuing memory fills
103
+ static ur_result_t enqueueCommandBufferFillHelper (
104
+ ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
105
+ const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
106
+ size_t Size, uint32_t NumSyncPointsInWaitList,
107
+ const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
108
+ ur_exp_command_buffer_sync_point_t *SyncPoint) {
109
+ ur_result_t Result = UR_RESULT_SUCCESS;
110
+ std::vector<CUgraphNode> DepsList;
111
+ UR_CALL (getNodesFromSyncPoints (CommandBuffer, NumSyncPointsInWaitList,
112
+ SyncPointWaitList, DepsList),
113
+ Result);
114
+
115
+ try {
116
+ const size_t N = Size / PatternSize;
117
+ auto Value = *static_cast <const uint32_t *>(Pattern);
118
+ auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
119
+ ? *static_cast <CUdeviceptr *>(DstDevice)
120
+ : (CUdeviceptr)DstDevice;
121
+
122
+ if ((PatternSize == 1 ) || (PatternSize == 2 ) || (PatternSize == 4 )) {
123
+ // Create a new node
124
+ CUgraphNode GraphNode;
125
+ CUDA_MEMSET_NODE_PARAMS NodeParams = {};
126
+ NodeParams.dst = DstPtr;
127
+ NodeParams.elementSize = PatternSize;
128
+ NodeParams.height = N;
129
+ NodeParams.pitch = PatternSize;
130
+ NodeParams.value = Value;
131
+ NodeParams.width = 1 ;
132
+
133
+ UR_CHECK_ERROR (cuGraphAddMemsetNode (
134
+ &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
135
+ DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
136
+
137
+ // Get sync point and register the cuNode with it.
138
+ *SyncPoint =
139
+ CommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
140
+
141
+ } else {
142
+ // CUDA has no memset functions that allow setting values more than 4
143
+ // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
144
+ // fill, which can be more than 4 bytes. We must break up the pattern
145
+ // into 4 byte values, and set the buffer using multiple strided calls.
146
+ // This means that one cuGraphAddMemsetNode call is made for every 4 bytes
147
+ // in the pattern.
148
+
149
+ size_t NumberOfSteps = PatternSize / sizeof (uint32_t );
150
+
151
+ // we walk up the pattern in 4-byte steps, and call cuMemset for each
152
+ // 4-byte chunk of the pattern.
153
+ for (auto Step = 0u ; Step < NumberOfSteps; ++Step) {
154
+ // take 4 bytes of the pattern
155
+ auto Value = *(static_cast <const uint32_t *>(Pattern) + Step);
156
+
157
+ // offset the pointer to the part of the buffer we want to write to
158
+ auto OffsetPtr = DstPtr + (Step * sizeof (uint32_t ));
159
+
160
+ // Create a new node
161
+ CUgraphNode GraphNode;
162
+ // Update NodeParam
163
+ CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
164
+ NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
165
+ NodeParamsStep.elementSize = 4 ;
166
+ NodeParamsStep.height = N;
167
+ NodeParamsStep.pitch = PatternSize;
168
+ NodeParamsStep.value = Value;
169
+ NodeParamsStep.width = 1 ;
170
+
171
+ UR_CHECK_ERROR (cuGraphAddMemsetNode (
172
+ &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
173
+ DepsList.size (), &NodeParamsStep,
174
+ CommandBuffer->Device ->getContext ()));
175
+
176
+ // Get sync point and register the cuNode with it.
177
+ *SyncPoint = CommandBuffer->AddSyncPoint (
178
+ std::make_shared<CUgraphNode>(GraphNode));
179
+ }
180
+ }
181
+ } catch (ur_result_t Err) {
182
+ Result = Err;
183
+ }
184
+ return Result;
185
+ }
186
+
102
187
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp (
103
188
ur_context_handle_t hContext, ur_device_handle_t hDevice,
104
189
const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
@@ -525,6 +610,119 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
525
610
return Result;
526
611
}
527
612
613
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp (
614
+ ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */ ,
615
+ size_t /* Size*/ , ur_usm_migration_flags_t /* Flags*/ ,
616
+ uint32_t numSyncPointsInWaitList,
617
+ const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
618
+ ur_exp_command_buffer_sync_point_t *pSyncPoint) {
619
+ // Prefetch cmd is not supported by Cuda Graph.
620
+ // We implement it as an empty node to enforce dependencies.
621
+ ur_result_t Result = UR_RESULT_SUCCESS;
622
+ CUgraphNode GraphNode;
623
+
624
+ std::vector<CUgraphNode> DepsList;
625
+ UR_CALL (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
626
+ pSyncPointWaitList, DepsList),
627
+ Result);
628
+
629
+ try {
630
+ // Add an empty node to preserve dependencies.
631
+ UR_CHECK_ERROR (cuGraphAddEmptyNode (&GraphNode, hCommandBuffer->CudaGraph ,
632
+ DepsList.data (), DepsList.size ()));
633
+
634
+ // Get sync point and register the cuNode with it.
635
+ *pSyncPoint =
636
+ hCommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
637
+
638
+ setErrorMessage (" Prefetch hint ignored and replaced with empty node as "
639
+ " prefetch is not supported by CUDA Graph backend" ,
640
+ UR_RESULT_SUCCESS);
641
+ Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
642
+ } catch (ur_result_t Err) {
643
+ Result = Err;
644
+ }
645
+ return Result;
646
+ }
647
+
648
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp (
649
+ ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */ ,
650
+ size_t /* Size*/ , ur_usm_advice_flags_t /* Advice*/ ,
651
+ uint32_t numSyncPointsInWaitList,
652
+ const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
653
+ ur_exp_command_buffer_sync_point_t *pSyncPoint) {
654
+ // Mem-Advise cmd is not supported by Cuda Graph.
655
+ // We implement it as an empty node to enforce dependencies.
656
+ ur_result_t Result = UR_RESULT_SUCCESS;
657
+ CUgraphNode GraphNode;
658
+
659
+ std::vector<CUgraphNode> DepsList;
660
+ UR_CALL (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
661
+ pSyncPointWaitList, DepsList),
662
+ Result);
663
+
664
+ try {
665
+ // Add an empty node to preserve dependencies.
666
+ UR_CHECK_ERROR (cuGraphAddEmptyNode (&GraphNode, hCommandBuffer->CudaGraph ,
667
+ DepsList.data (), DepsList.size ()));
668
+
669
+ // Get sync point and register the cuNode with it.
670
+ *pSyncPoint =
671
+ hCommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
672
+
673
+ setErrorMessage (" Memory advice ignored and replaced with empty node as "
674
+ " memory advice is not supported by CUDA Graph backend" ,
675
+ UR_RESULT_SUCCESS);
676
+ Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
677
+ } catch (ur_result_t Err) {
678
+ Result = Err;
679
+ }
680
+
681
+ return Result;
682
+ }
683
+
684
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp (
685
+ ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
686
+ const void *pPattern, size_t patternSize, size_t offset, size_t size,
687
+ uint32_t numSyncPointsInWaitList,
688
+ const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
689
+ ur_exp_command_buffer_sync_point_t *pSyncPoint) {
690
+ auto ArgsAreMultiplesOfPatternSize =
691
+ (offset % patternSize == 0 ) || (size % patternSize == 0 );
692
+
693
+ auto PatternIsValid = (pPattern != nullptr );
694
+
695
+ auto PatternSizeIsValid = ((patternSize & (patternSize - 1 )) == 0 ) &&
696
+ (patternSize > 0 ); // is a positive power of two
697
+ UR_ASSERT (ArgsAreMultiplesOfPatternSize && PatternIsValid &&
698
+ PatternSizeIsValid,
699
+ UR_RESULT_ERROR_INVALID_SIZE);
700
+
701
+ auto DstDevice = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
702
+
703
+ return enqueueCommandBufferFillHelper (
704
+ hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
705
+ size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
706
+ }
707
+
708
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp (
709
+ ur_exp_command_buffer_handle_t hCommandBuffer, void *pPtr,
710
+ const void *pPattern, size_t patternSize, size_t size,
711
+ uint32_t numSyncPointsInWaitList,
712
+ const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
713
+ ur_exp_command_buffer_sync_point_t *pSyncPoint) {
714
+
715
+ auto PatternIsValid = (pPattern != nullptr );
716
+
717
+ auto PatternSizeIsValid = ((patternSize & (patternSize - 1 )) == 0 ) &&
718
+ (patternSize > 0 ); // is a positive power of two
719
+
720
+ UR_ASSERT (PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
721
+ return enqueueCommandBufferFillHelper (
722
+ hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
723
+ numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
724
+ }
725
+
528
726
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp (
529
727
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
530
728
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
0 commit comments