Skip to content

Commit a04b062

Browse files
committed
Extended native CPU fill to support patterns larger than 1 byte
1 parent 03b7148 commit a04b062

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

source/adapters/native_cpu/enqueue.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
374374
UR_ASSERT(size % patternSize == 0 || patternSize > size,
375375
UR_RESULT_ERROR_INVALID_SIZE);
376376

377-
memset(ptr, *static_cast<const uint8_t *>(pPattern), size * patternSize);
378-
377+
switch (patternSize) {
378+
case 1:
379+
memset(ptr, *static_cast<const uint8_t *>(pPattern), size * patternSize);
380+
break;
381+
case 2: {
382+
const auto pattern = *static_cast<const uint16_t *>(pPattern);
383+
auto *start = reinterpret_cast<uint16_t *>(ptr);
384+
auto *end =
385+
reinterpret_cast<uint16_t *>(reinterpret_cast<uint16_t *>(ptr) + size);
386+
std::fill(start, end, pattern);
387+
break;
388+
}
389+
case 4: {
390+
const auto pattern = *static_cast<const uint32_t *>(pPattern);
391+
auto *start = reinterpret_cast<uint32_t *>(ptr);
392+
auto *end =
393+
reinterpret_cast<uint32_t *>(reinterpret_cast<uint32_t *>(ptr) + size);
394+
std::fill(start, end, pattern);
395+
break;
396+
}
397+
case 8: {
398+
const auto pattern = *static_cast<const uint64_t *>(pPattern);
399+
auto *start = reinterpret_cast<uint64_t *>(ptr);
400+
auto *end =
401+
reinterpret_cast<uint64_t *>(reinterpret_cast<uint64_t *>(ptr) + size);
402+
std::fill(start, end, pattern);
403+
break;
404+
}
405+
default:
406+
for (unsigned int step{0}; step < size; ++step) {
407+
auto *dest = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(ptr) +
408+
step * patternSize);
409+
memcpy(dest, pPattern, patternSize);
410+
}
411+
}
379412
return UR_RESULT_SUCCESS;
380413
}
381414

0 commit comments

Comments
 (0)