@@ -511,8 +511,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
511511
512512 UR_ASSERT (ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER);
513513 UR_ASSERT (pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER);
514- UR_ASSERT (size % patternSize == 0 || patternSize > size,
515- UR_RESULT_ERROR_INVALID_SIZE);
514+ UR_ASSERT (patternSize != 0 , UR_RESULT_ERROR_INVALID_SIZE)
515+ UR_ASSERT (size != 0 , UR_RESULT_ERROR_INVALID_SIZE)
516+ UR_ASSERT (patternSize < size, UR_RESULT_ERROR_INVALID_SIZE)
517+ UR_ASSERT (size % patternSize == 0 , UR_RESULT_ERROR_INVALID_SIZE)
518+ // TODO: add check for allocation size once the query is supported
516519
517520 switch (patternSize) {
518521 case 1 :
@@ -522,33 +525,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
522525 const auto pattern = *static_cast <const uint16_t *>(pPattern);
523526 auto *start = reinterpret_cast <uint16_t *>(ptr);
524527 auto *end =
525- reinterpret_cast <uint16_t *>(reinterpret_cast <uint16_t *>(ptr) + size);
528+ reinterpret_cast <uint16_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
526529 std::fill (start, end, pattern);
527530 break ;
528531 }
529532 case 4 : {
530533 const auto pattern = *static_cast <const uint32_t *>(pPattern);
531534 auto *start = reinterpret_cast <uint32_t *>(ptr);
532535 auto *end =
533- reinterpret_cast <uint32_t *>(reinterpret_cast <uint32_t *>(ptr) + size);
536+ reinterpret_cast <uint32_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
534537 std::fill (start, end, pattern);
535538 break ;
536539 }
537540 case 8 : {
538541 const auto pattern = *static_cast <const uint64_t *>(pPattern);
539542 auto *start = reinterpret_cast <uint64_t *>(ptr);
540543 auto *end =
541- reinterpret_cast <uint64_t *>(reinterpret_cast <uint64_t *>(ptr) + size);
544+ reinterpret_cast <uint64_t *>(reinterpret_cast <uint8_t *>(ptr) + size);
542545 std::fill (start, end, pattern);
543546 break ;
544547 }
545- default :
546- for (unsigned int step{0 }; step < size; ++ step) {
547- auto *dest = reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(ptr) +
548- step * patternSize );
548+ default : {
549+ for (unsigned int step{0 }; step < size; step += patternSize ) {
550+ auto *dest =
551+ reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(ptr) + step);
549552 memcpy (dest, pPattern, patternSize);
550553 }
551554 }
555+ }
552556 return UR_RESULT_SUCCESS;
553557}
554558
@@ -583,7 +587,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
583587 std::ignore = phEventWaitList;
584588 std::ignore = phEvent;
585589
586- DIE_NO_IMPLEMENTATION;
590+ // TODO: properly implement USM prefetch
591+ return UR_RESULT_SUCCESS;
587592}
588593
589594UR_APIEXPORT ur_result_t UR_APICALL
@@ -595,7 +600,8 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
595600 std::ignore = advice;
596601 std::ignore = phEvent;
597602
598- DIE_NO_IMPLEMENTATION;
603+ // TODO: properly implement USM advise
604+ return UR_RESULT_SUCCESS;
599605}
600606
601607UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D (
0 commit comments