@@ -504,8 +504,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
504504 UR_ASSERT (size % patternSize == 0 || patternSize > size,
505505 UR_RESULT_ERROR_INVALID_SIZE);
506506
507- memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
508-
507+ switch (patternSize) {
508+ case 1 :
509+ memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
510+ break ;
511+ case 2 : {
512+ const auto pattern = *static_cast <const uint16_t *>(pPattern);
513+ auto *start = reinterpret_cast <uint16_t *>(ptr);
514+ auto *end =
515+ reinterpret_cast <uint16_t *>(reinterpret_cast <uint16_t *>(ptr) + size);
516+ std::fill (start, end, pattern);
517+ break ;
518+ }
519+ case 4 : {
520+ const auto pattern = *static_cast <const uint32_t *>(pPattern);
521+ auto *start = reinterpret_cast <uint32_t *>(ptr);
522+ auto *end =
523+ reinterpret_cast <uint32_t *>(reinterpret_cast <uint32_t *>(ptr) + size);
524+ std::fill (start, end, pattern);
525+ break ;
526+ }
527+ case 8 : {
528+ const auto pattern = *static_cast <const uint64_t *>(pPattern);
529+ auto *start = reinterpret_cast <uint64_t *>(ptr);
530+ auto *end =
531+ reinterpret_cast <uint64_t *>(reinterpret_cast <uint64_t *>(ptr) + size);
532+ std::fill (start, end, pattern);
533+ break ;
534+ }
535+ default :
536+ for (unsigned int step{0 }; step < size; ++step) {
537+ auto *dest = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(ptr) +
538+ step * patternSize);
539+ memcpy (dest, pPattern, patternSize);
540+ }
541+ }
509542 return UR_RESULT_SUCCESS;
510543}
511544
0 commit comments