@@ -712,25 +712,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
712712
713713static inline void memsetRemainPattern (hipStream_t Stream, uint32_t PatternSize,
714714 size_t Size, const void *pPattern,
715- hipDeviceptr_t Ptr) {
715+ hipDeviceptr_t Ptr,
716+ uint32_t StartOffset) {
717+ // Calculate the number of times the pattern needs to be applied
718+ auto Height = Size / PatternSize;
716719
717- // Calculate the number of patterns, stride and the number of times the
718- // pattern needs to be applied.
719- auto NumberOfSteps = PatternSize / sizeof (uint8_t );
720- auto Pitch = NumberOfSteps * sizeof (uint8_t );
721- auto Height = Size / NumberOfSteps;
722-
723- for (auto step = 4u ; step < NumberOfSteps; ++step) {
720+ for (auto step = StartOffset; step < PatternSize; ++step) {
724721 // take 1 byte of the pattern
725722 auto Value = *(static_cast <const uint8_t *>(pPattern) + step);
726723
727724 // offset the pointer to the part of the buffer we want to write to
728- auto OffsetPtr = reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(Ptr) +
729- (step * sizeof ( uint8_t )) );
725+ auto OffsetPtr =
726+ reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(Ptr) + step );
730727
731728 // set all of the pattern chunks
732- UR_CHECK_ERROR (hipMemset2DAsync (OffsetPtr, Pitch, Value, sizeof ( uint8_t ),
733- Height, Stream));
729+ UR_CHECK_ERROR (
730+ hipMemset2DAsync (OffsetPtr, PatternSize, Value, 1u , Height, Stream));
734731 }
735732}
736733
@@ -743,11 +740,55 @@ static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
743740ur_result_t commonMemSetLargePattern (hipStream_t Stream, uint32_t PatternSize,
744741 size_t Size, const void *pPattern,
745742 hipDeviceptr_t Ptr) {
743+ // Find the largest supported word size into which the pattern can be divided
744+ auto BackendWordSize = PatternSize % 4u == 0u ? 4u
745+ : PatternSize % 2u == 0u ? 2u
746+ : 1u ;
747+
748+ // Calculate the number of backend-sized words that make up one pattern
749+ auto NumberOfSteps = PatternSize / BackendWordSize;
750+
751+ // If the pattern is 1 word or the first word is repeated throughout, a fast
752+ // continuous fill can be used without the need for slower strided fills
753+ bool UseOnlyFirstValue{true };
754+ auto checkIfFirstWordRepeats = [&UseOnlyFirstValue,
755+ NumberOfSteps](const auto *pPatternWords) {
756+ for (auto Step{1u }; (Step < NumberOfSteps) && UseOnlyFirstValue; ++Step) {
757+ if (*(pPatternWords + Step) != *pPatternWords) {
758+ UseOnlyFirstValue = false ;
759+ }
760+ }
761+ };
746762
747- // Get 4-byte chunk of the pattern and call hipMemsetD32Async
748- auto Count32 = Size / sizeof (uint32_t );
749- auto Value = *(static_cast <const uint32_t *>(pPattern));
750- UR_CHECK_ERROR (hipMemsetD32Async (Ptr, Value, Count32, Stream));
763+ // Use a continuous fill for the first word in the pattern because it's faster
764+ // than a strided fill. Then, overwrite the other values in subsequent steps.
765+ switch (BackendWordSize) {
766+ case 4u : {
767+ auto *pPatternWords = static_cast <const uint32_t *>(pPattern);
768+ checkIfFirstWordRepeats (pPatternWords);
769+ UR_CHECK_ERROR (
770+ hipMemsetD32Async (Ptr, *pPatternWords, Size / BackendWordSize, Stream));
771+ break ;
772+ }
773+ case 2u : {
774+ auto *pPatternWords = static_cast <const uint16_t *>(pPattern);
775+ checkIfFirstWordRepeats (pPatternWords);
776+ UR_CHECK_ERROR (
777+ hipMemsetD16Async (Ptr, *pPatternWords, Size / BackendWordSize, Stream));
778+ break ;
779+ }
780+ default : {
781+ auto *pPatternWords = static_cast <const uint8_t *>(pPattern);
782+ checkIfFirstWordRepeats (pPatternWords);
783+ UR_CHECK_ERROR (
784+ hipMemsetD8Async (Ptr, *pPatternWords, Size / BackendWordSize, Stream));
785+ break ;
786+ }
787+ }
788+
789+ if (UseOnlyFirstValue) {
790+ return UR_RESULT_SUCCESS;
791+ }
751792
752793 // There is a bug in ROCm prior to 6.0.0 version which causes hipMemset2D
753794 // to behave incorrectly when acting on host pinned memory.
@@ -761,7 +802,7 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
761802 // we need to check that isManaged attribute is false.
762803 if (ptrAttribs.hostPointer && !ptrAttribs.isManaged ) {
763804 const auto NumOfCopySteps = Size / PatternSize;
764- const auto Offset = sizeof ( uint32_t ) ;
805+ const auto Offset = BackendWordSize ;
765806 const auto LeftPatternSize = PatternSize - Offset;
766807 const auto OffsetPatternPtr = reinterpret_cast <const void *>(
767808 reinterpret_cast <const uint8_t *>(pPattern) + Offset);
@@ -776,10 +817,12 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
776817 Stream));
777818 }
778819 } else {
779- memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr);
820+ memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr,
821+ BackendWordSize);
780822 }
781823#else
782- memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr);
824+ memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr,
825+ BackendWordSize);
783826#endif
784827 return UR_RESULT_SUCCESS;
785828}
0 commit comments