@@ -99,6 +99,91 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
9999 Params.Depth = 1 ;
100100}
101101
102+ // Helper function for enqueuing memory fills
103+ static ur_result_t enqueueCommandBufferFillHelper (
104+ ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
105+ const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
106+ size_t Size, uint32_t NumSyncPointsInWaitList,
107+ const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
108+ ur_exp_command_buffer_sync_point_t *SyncPoint) {
109+ ur_result_t Result = UR_RESULT_SUCCESS;
110+ std::vector<CUgraphNode> DepsList;
111+ UR_CALL (getNodesFromSyncPoints (CommandBuffer, NumSyncPointsInWaitList,
112+ SyncPointWaitList, DepsList),
113+ Result);
114+
115+ try {
116+ const size_t N = Size / PatternSize;
117+ auto Value = *static_cast <const uint32_t *>(Pattern);
118+ auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
119+ ? *static_cast <CUdeviceptr *>(DstDevice)
120+ : (CUdeviceptr)DstDevice;
121+
122+ if ((PatternSize == 1 ) || (PatternSize == 2 ) || (PatternSize == 4 )) {
123+ // Create a new node
124+ CUgraphNode GraphNode;
125+ CUDA_MEMSET_NODE_PARAMS NodeParams = {};
126+ NodeParams.dst = DstPtr;
127+ NodeParams.elementSize = PatternSize;
128+ NodeParams.height = N;
129+ NodeParams.pitch = PatternSize;
130+ NodeParams.value = Value;
131+ NodeParams.width = 1 ;
132+
133+ UR_CHECK_ERROR (cuGraphAddMemsetNode (
134+ &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
135+ DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
136+
137+ // Get sync point and register the cuNode with it.
138+ *SyncPoint =
139+ CommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
140+
141+ } else {
142+ // CUDA has no memset functions that allow setting values more than 4
143+ // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
144+ // fill, which can be more than 4 bytes. We must break up the pattern
145+ // into 4 byte values, and set the buffer using multiple strided calls.
146+ // This means that one cuGraphAddMemsetNode call is made for every 4 bytes
147+ // in the pattern.
148+
149+ size_t NumberOfSteps = PatternSize / sizeof (uint32_t );
150+
151+ // we walk up the pattern in 4-byte steps, and call cuMemset for each
152+ // 4-byte chunk of the pattern.
153+ for (auto Step = 0u ; Step < NumberOfSteps; ++Step) {
154+ // take 4 bytes of the pattern
155+ auto Value = *(static_cast <const uint32_t *>(Pattern) + Step);
156+
157+ // offset the pointer to the part of the buffer we want to write to
158+ auto OffsetPtr = DstPtr + (Step * sizeof (uint32_t ));
159+
160+ // Create a new node
161+ CUgraphNode GraphNode;
162+ // Update NodeParam
163+ CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
164+ NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
165+ NodeParamsStep.elementSize = 4 ;
166+ NodeParamsStep.height = N;
167+ NodeParamsStep.pitch = PatternSize;
168+ NodeParamsStep.value = Value;
169+ NodeParamsStep.width = 1 ;
170+
171+ UR_CHECK_ERROR (cuGraphAddMemsetNode (
172+ &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
173+ DepsList.size (), &NodeParamsStep,
174+ CommandBuffer->Device ->getContext ()));
175+
176+ // Get sync point and register the cuNode with it.
177+ *SyncPoint = CommandBuffer->AddSyncPoint (
178+ std::make_shared<CUgraphNode>(GraphNode));
179+ }
180+ }
181+ } catch (ur_result_t Err) {
182+ Result = Err;
183+ }
184+ return Result;
185+ }
186+
102187UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp (
103188 ur_context_handle_t hContext, ur_device_handle_t hDevice,
104189 const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
@@ -596,6 +681,48 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
596681 return Result;
597682}
598683
684+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp (
685+ ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
686+ const void *pPattern, size_t patternSize, size_t offset, size_t size,
687+ uint32_t numSyncPointsInWaitList,
688+ const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
689+ ur_exp_command_buffer_sync_point_t *pSyncPoint) {
690+ auto ArgsAreMultiplesOfPatternSize =
691+ (offset % patternSize == 0 ) || (size % patternSize == 0 );
692+
693+ auto PatternIsValid = (pPattern != nullptr );
694+
695+ auto PatternSizeIsValid = ((patternSize & (patternSize - 1 )) == 0 ) &&
696+ (patternSize > 0 ); // is a positive power of two
697+ UR_ASSERT (ArgsAreMultiplesOfPatternSize && PatternIsValid &&
698+ PatternSizeIsValid,
699+ UR_RESULT_ERROR_INVALID_SIZE);
700+
701+ auto DstDevice = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
702+
703+ return enqueueCommandBufferFillHelper (
704+ hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
705+ size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
706+ }
707+
708+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp (
709+ ur_exp_command_buffer_handle_t hCommandBuffer, void *pPtr,
710+ const void *pPattern, size_t patternSize, size_t size,
711+ uint32_t numSyncPointsInWaitList,
712+ const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
713+ ur_exp_command_buffer_sync_point_t *pSyncPoint) {
714+
715+ auto PatternIsValid = (pPattern != nullptr );
716+
717+ auto PatternSizeIsValid = ((patternSize & (patternSize - 1 )) == 0 ) &&
718+ (patternSize > 0 ); // is a positive power of two
719+
720+ UR_ASSERT (PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
721+ return enqueueCommandBufferFillHelper (
722+ hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
723+ numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
724+ }
725+
599726UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp (
600727 ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
601728 uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
0 commit comments