Skip to content

Commit 5e99670

Browse files
committed
Fix incorrect outputs and improve performance of commonMemSetLargePattern
Change the implementation of commonMemSetLargePattern to use the largest pattern word size supported by the backend into which the pattern can be divided. That is, use 4-byte words if the pattern size is a multiple of 4, 2-byte words for even sizes and 1-byte words for odd sizes. Keep the idea of filling the entire destination region with the first word, and only start strided fill from the second, but implement it correctly. The previous implementation produced incorrect results for any pattern size which wasn't a multiple of 4. Add a new optimisation skipping the strided fills completely if the pattern is equal to the first word repeated throughout. This is most commonly the case for a pattern of all zeros, but other cases are possible.
1 parent a9c7aef commit 5e99670

File tree

1 file changed

+50
-25
lines changed

1 file changed

+50
-25
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -953,35 +953,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
953953

954954
// CUDA has no memset functions that allow setting values more than 4 bytes. UR
955955
// API lets you pass an arbitrary "pattern" to the buffer fill, which can be
956-
// more than 4 bytes. We must break up the pattern into 1 byte values, and set
957-
// the buffer using multiple strided calls. The first 4 patterns are set using
958-
// cuMemsetD32Async then all subsequent 1 byte patterns are set using
959-
// cuMemset2DAsync which is called for each pattern.
956+
// more than 4 bytes. We must break up the pattern into 1, 2 or 4-byte values
957+
// and set the buffer using multiple strided calls.
960958
ur_result_t commonMemSetLargePattern(CUstream Stream, uint32_t PatternSize,
961959
size_t Size, const void *pPattern,
962960
CUdeviceptr Ptr) {
963-
// Calculate the number of patterns, stride, number of times the pattern
964-
// needs to be applied, and the number of times the first 32 bit pattern
965-
// needs to be applied.
966-
auto NumberOfSteps = PatternSize / sizeof(uint8_t);
967-
auto Pitch = NumberOfSteps * sizeof(uint8_t);
968-
auto Height = Size / NumberOfSteps;
969-
auto Count32 = Size / sizeof(uint32_t);
970-
971-
// Get 4-byte chunk of the pattern and call cuMemsetD32Async
972-
auto Value = *(static_cast<const uint32_t *>(pPattern));
973-
UR_CHECK_ERROR(cuMemsetD32Async(Ptr, Value, Count32, Stream));
974-
for (auto step = 4u; step < NumberOfSteps; ++step) {
975-
// take 1 byte of the pattern
976-
Value = *(static_cast<const uint8_t *>(pPattern) + step);
977-
978-
// offset the pointer to the part of the buffer we want to write to
979-
auto OffsetPtr = Ptr + (step * sizeof(uint8_t));
980-
981-
// set all of the pattern chunks
982-
UR_CHECK_ERROR(cuMemsetD2D8Async(OffsetPtr, Pitch, Value, sizeof(uint8_t),
983-
Height, Stream));
961+
// Find the largest supported word size into which the pattern can be divided
962+
auto BackendWordSize = PatternSize%4u==0u ? 4u : PatternSize%2u==0u ? 2u : 1u;
963+
964+
// Calculate the number of words in the pattern, the stride, and the number of
965+
// times the pattern needs to be applied
966+
auto NumberOfSteps = PatternSize / BackendWordSize;
967+
auto Pitch = NumberOfSteps * BackendWordSize;
968+
auto Height = Size / PatternSize;
969+
970+
// Same implementation works for any pattern word type (uint8_t, uint16_t, uint32_t)
971+
auto memsetImpl = [BackendWordSize, NumberOfSteps, Pitch, Height, Size, Ptr, &Stream](const auto* pPatternWords, auto&& continuousMemset, auto&& stridedMemset){
972+
// If the pattern is 1 word or the first word is repeated throughout, a fast
973+
// continuous fill can be used without the need for slower strided fills
974+
bool UseOnlyFirstValue{true};
975+
for (auto Step{1u}; (Step < NumberOfSteps) && UseOnlyFirstValue; ++Step) {
976+
if (*(pPatternWords + Step) != *pPatternWords) {
977+
UseOnlyFirstValue=false;
978+
}
979+
}
980+
auto OptimizedNumberOfSteps{UseOnlyFirstValue ? 1u : NumberOfSteps};
981+
982+
// Fill the pattern in steps of BackendWordSize bytes. Use a continuous
983+
// fill in the first step because it's faster than a strided fill. Then,
984+
// overwrite the other values in subsequent steps.
985+
for (auto Step{0u}; Step < OptimizedNumberOfSteps; ++Step) {
986+
if (Step==0) {
987+
UR_CHECK_ERROR(continuousMemset(Ptr, *(pPatternWords), Size / BackendWordSize, Stream));
988+
} else {
989+
UR_CHECK_ERROR(stridedMemset(Ptr + Step * BackendWordSize, Pitch, *(pPatternWords + Step), 1u, Height, Stream));
990+
}
991+
}
992+
};
993+
994+
// Apply the implementation to the chosen pattern word type
995+
switch (BackendWordSize) {
996+
case 4u: {
997+
memsetImpl(static_cast<const uint32_t *>(pPattern), cuMemsetD32Async, cuMemsetD2D32Async);
998+
break;
999+
}
1000+
case 2u: {
1001+
memsetImpl(static_cast<const uint16_t *>(pPattern), cuMemsetD16Async, cuMemsetD2D16Async);
1002+
break;
1003+
}
1004+
default: {
1005+
memsetImpl(static_cast<const uint8_t *>(pPattern), cuMemsetD8Async, cuMemsetD2D8Async);
1006+
break;
1007+
}
9841008
}
1009+
9851010
return UR_RESULT_SUCCESS;
9861011
}
9871012

0 commit comments

Comments
 (0)