@@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
170170
171171 try {
172172 const size_t N = Size / PatternSize;
173- auto Value = *static_cast <const uint32_t *>(Pattern);
174173 auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
175174 ? *static_cast <CUdeviceptr *>(DstDevice)
176175 : (CUdeviceptr)DstDevice;
@@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
183182 NodeParams.elementSize = PatternSize;
184183 NodeParams.height = N;
185184 NodeParams.pitch = PatternSize;
186- NodeParams.value = Value;
187185 NodeParams.width = 1 ;
188186
187+ // pattern size in bytes
188+ switch (PatternSize) {
189+ case 1 : {
190+ auto Value = *static_cast <const uint8_t *>(Pattern);
191+ NodeParams.value = Value;
192+ break ;
193+ }
194+ case 2 : {
195+ auto Value = *static_cast <const uint16_t *>(Pattern);
196+ NodeParams.value = Value;
197+ break ;
198+ }
199+ case 4 : {
200+ auto Value = *static_cast <const uint32_t *>(Pattern);
201+ NodeParams.value = Value;
202+ break ;
203+ }
204+ }
205+
189206 UR_CHECK_ERROR (cuGraphAddMemsetNode (
190207 &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
191208 DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
@@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper(
198215 // CUDA has no memset functions that allow setting values more than 4
199216 // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
200217 // fill, which can be more than 4 bytes. We must break up the pattern
201- // into 4 byte values, and set the buffer using multiple strided calls.
202- // This means that one cuGraphAddMemsetNode call is made for every 4 bytes
203- // in the pattern.
218+ // into 1 byte values, and set the buffer using multiple strided calls.
219+ // This means that one cuGraphAddMemsetNode call is made for the first 4
220+ // bytes of the pattern, and one more for every byte after that.
221+
222+ size_t NumberOfSteps = PatternSize / sizeof (uint8_t );
204223
205- size_t NumberOfSteps = PatternSize / sizeof (uint32_t );
224+ // Shared pointer that will point to the last node created
225+ std::shared_ptr<CUgraphNode> GraphNodePtr;
226+ // Create a new node
227+ CUgraphNode GraphNodeFirst;
228+ // Update NodeParam
229+ CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
230+ NodeParamsStepFirst.dst = DstPtr;
231+ NodeParamsStepFirst.elementSize = sizeof (uint32_t );
232+ NodeParamsStepFirst.height = Size / sizeof (uint32_t );
233+ NodeParamsStepFirst.pitch = sizeof (uint32_t );
234+ NodeParamsStepFirst.value = *static_cast <const uint32_t *>(Pattern);
235+ NodeParamsStepFirst.width = 1 ;
206236
207- // we walk up the pattern in 4-byte steps, and call cuMemset for each
208- // 4-byte chunk of the pattern.
209- for (auto Step = 0u ; Step < NumberOfSteps; ++Step) {
237+ UR_CHECK_ERROR (cuGraphAddMemsetNode (
238+ &GraphNodeFirst, CommandBuffer->CudaGraph , DepsList.data (),
239+ DepsList.size (), &NodeParamsStepFirst,
240+ CommandBuffer->Device ->getContext ()));
241+
242+ // Get sync point and register the cuNode with it.
243+ *SyncPoint = CommandBuffer->addSyncPoint (
244+ std::make_shared<CUgraphNode>(GraphNodeFirst));
245+
246+ DepsList.clear ();
247+ DepsList.push_back (GraphNodeFirst);
248+
249+ // we walk up the pattern in 1-byte steps, and call cuMemset for each
250+ // 1-byte chunk of the pattern.
251+ for (auto Step = 4u ; Step < NumberOfSteps; ++Step) {
210252 // take 1 byte of the pattern
211- auto Value = *(static_cast <const uint32_t *>(Pattern) + Step);
253+ auto Value = *(static_cast <const uint8_t *>(Pattern) + Step);
212254
213255 // offset the pointer to the part of the buffer we want to write to
214- auto OffsetPtr = DstPtr + (Step * sizeof (uint32_t ));
256+ auto OffsetPtr = DstPtr + (Step * sizeof (uint8_t ));
215257
216258 // Create a new node
217259 CUgraphNode GraphNode;
218260 // Update NodeParam
219261 CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
220262 NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
221- NodeParamsStep.elementSize = 4 ;
222- NodeParamsStep.height = N ;
223- NodeParamsStep.pitch = PatternSize ;
263+ NodeParamsStep.elementSize = sizeof ( uint8_t ) ;
264+ NodeParamsStep.height = Size / NumberOfSteps ;
265+ NodeParamsStep.pitch = NumberOfSteps * sizeof ( uint8_t ) ;
224266 NodeParamsStep.value = Value;
225267 NodeParamsStep.width = 1 ;
226268
@@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
229271 DepsList.size (), &NodeParamsStep,
230272 CommandBuffer->Device ->getContext ()));
231273
274+ GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
232275 // Get sync point and register the cuNode with it.
233- *SyncPoint = CommandBuffer->addSyncPoint (
234- std::make_shared<CUgraphNode>(GraphNode));
276+ *SyncPoint = CommandBuffer->addSyncPoint (GraphNodePtr);
277+
278+ DepsList.clear ();
279+ DepsList.push_back (*GraphNodePtr.get ());
235280 }
236281 }
237282 } catch (ur_result_t Err) {
0 commit comments