@@ -374,8 +374,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
374
374
UR_ASSERT (size % patternSize == 0 || patternSize > size,
375
375
UR_RESULT_ERROR_INVALID_SIZE);
376
376
377
- memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
378
-
377
+ switch (patternSize) {
378
+ case 1 :
379
+ memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
380
+ break ;
381
+ case 2 : {
382
+ const auto pattern = *static_cast <const uint16_t *>(pPattern);
383
+ auto *start = reinterpret_cast <uint16_t *>(ptr);
384
+ auto *end =
385
+ reinterpret_cast <uint16_t *>(reinterpret_cast <uint16_t *>(ptr) + size);
386
+ std::fill (start, end, pattern);
387
+ break ;
388
+ }
389
+ case 4 : {
390
+ const auto pattern = *static_cast <const uint32_t *>(pPattern);
391
+ auto *start = reinterpret_cast <uint32_t *>(ptr);
392
+ auto *end =
393
+ reinterpret_cast <uint32_t *>(reinterpret_cast <uint32_t *>(ptr) + size);
394
+ std::fill (start, end, pattern);
395
+ break ;
396
+ }
397
+ case 8 : {
398
+ const auto pattern = *static_cast <const uint64_t *>(pPattern);
399
+ auto *start = reinterpret_cast <uint64_t *>(ptr);
400
+ auto *end =
401
+ reinterpret_cast <uint64_t *>(reinterpret_cast <uint64_t *>(ptr) + size);
402
+ std::fill (start, end, pattern);
403
+ break ;
404
+ }
405
+ default :
406
+ for (unsigned int step{0 }; step < size; ++step) {
407
+ auto *dest = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(ptr) +
408
+ step * patternSize);
409
+ memcpy (dest, pPattern, patternSize);
410
+ }
411
+ }
379
412
return UR_RESULT_SUCCESS;
380
413
}
381
414
0 commit comments