@@ -504,8 +504,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
504
504
UR_ASSERT (size % patternSize == 0 || patternSize > size,
505
505
UR_RESULT_ERROR_INVALID_SIZE);
506
506
507
- memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
508
-
507
+ switch (patternSize) {
508
+ case 1 :
509
+ memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
510
+ break ;
511
+ case 2 : {
512
+ const auto pattern = *static_cast <const uint16_t *>(pPattern);
513
+ auto *start = reinterpret_cast <uint16_t *>(ptr);
514
+ auto *end =
515
+ reinterpret_cast <uint16_t *>(reinterpret_cast <uint16_t *>(ptr) + size);
516
+ std::fill (start, end, pattern);
517
+ break ;
518
+ }
519
+ case 4 : {
520
+ const auto pattern = *static_cast <const uint32_t *>(pPattern);
521
+ auto *start = reinterpret_cast <uint32_t *>(ptr);
522
+ auto *end =
523
+ reinterpret_cast <uint32_t *>(reinterpret_cast <uint32_t *>(ptr) + size);
524
+ std::fill (start, end, pattern);
525
+ break ;
526
+ }
527
+ case 8 : {
528
+ const auto pattern = *static_cast <const uint64_t *>(pPattern);
529
+ auto *start = reinterpret_cast <uint64_t *>(ptr);
530
+ auto *end =
531
+ reinterpret_cast <uint64_t *>(reinterpret_cast <uint64_t *>(ptr) + size);
532
+ std::fill (start, end, pattern);
533
+ break ;
534
+ }
535
+ default :
536
+ for (unsigned int step{0 }; step < size; ++step) {
537
+ auto *dest = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(ptr) +
538
+ step * patternSize);
539
+ memcpy (dest, pPattern, patternSize);
540
+ }
541
+ }
509
542
return UR_RESULT_SUCCESS;
510
543
}
511
544
0 commit comments