@@ -374,8 +374,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
374374 UR_ASSERT (size % patternSize == 0 || patternSize > size,
375375 UR_RESULT_ERROR_INVALID_SIZE);
376376
377- memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
378-
377+ switch (patternSize) {
378+ case 1 :
379+ memset (ptr, *static_cast <const uint8_t *>(pPattern), size * patternSize);
380+ break ;
381+ case 2 : {
382+ const auto pattern = *static_cast <const uint16_t *>(pPattern);
383+ auto *start = reinterpret_cast <uint16_t *>(ptr);
384+ auto *end =
385+ reinterpret_cast <uint16_t *>(reinterpret_cast <uint16_t *>(ptr) + size);
386+ std::fill (start, end, pattern);
387+ break ;
388+ }
389+ case 4 : {
390+ const auto pattern = *static_cast <const uint32_t *>(pPattern);
391+ auto *start = reinterpret_cast <uint32_t *>(ptr);
392+ auto *end =
393+ reinterpret_cast <uint32_t *>(reinterpret_cast <uint32_t *>(ptr) + size);
394+ std::fill (start, end, pattern);
395+ break ;
396+ }
397+ case 8 : {
398+ const auto pattern = *static_cast <const uint64_t *>(pPattern);
399+ auto *start = reinterpret_cast <uint64_t *>(ptr);
400+ auto *end =
401+ reinterpret_cast <uint64_t *>(reinterpret_cast <uint64_t *>(ptr) + size);
402+ std::fill (start, end, pattern);
403+ break ;
404+ }
405+ default :
406+ for (unsigned int step{0 }; step < size; ++step) {
407+ auto *dest = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(ptr) +
408+ step * patternSize);
409+ memcpy (dest, pPattern, patternSize);
410+ }
411+ }
379412 return UR_RESULT_SUCCESS;
380413}
381414
0 commit comments