Skip to content

Commit 97e8a10

Browse files
committed
Refactor to prevent extending chain too soon
1 parent b9fa06d commit 97e8a10

File tree

1 file changed

+69
-62
lines changed

1 file changed

+69
-62
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 69 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -946,72 +946,79 @@ std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
946946
}
947947
}
948948

949-
// The vectorizer does not support non-power-of-2 element count vectors.
950-
// Extend the chain to the next power-of-2 if the current chain:
951-
// 1. Does not have a power-of-2 element count
952-
// 2. Would be legal to vectorize if the element count was extended to
953-
// the next power-of-2
954949
Chain ExtendingLoadsStores;
955-
if (NumVecElems < TargetVF && !isPowerOf2_32(NumVecElems) &&
956-
VecElemBits >= 8) {
957-
// TargetVF may be a lot higher than NumVecElems,
958-
// so only extend to the next power of 2.
959-
assert(VecElemBits % 8 == 0);
960-
unsigned VecElemBytes = VecElemBits / 8;
961-
unsigned NewNumVecElems = PowerOf2Ceil(NumVecElems);
962-
unsigned NewSizeBytes = VecElemBytes * NewNumVecElems;
963-
964-
assert(isPowerOf2_32(TargetVF) &&
965-
"TargetVF expected to be a power of 2");
966-
assert(NewNumVecElems <= TargetVF && "Should not extend past TargetVF");
967-
968-
LLVM_DEBUG(dbgs() << "LSV: attempting to extend chain of "
969-
<< NumVecElems << " "
970-
<< (IsLoadChain ? "loads" : "stores") << " to "
971-
<< NewNumVecElems << " elements\n");
972-
// Only artificially increase the chain if it would be AllowedAndFast
973-
// and if the resulting masked load/store will be legal for the target.
974-
if (accessIsAllowedAndFast(NewSizeBytes, AS, Alignment, VecElemBits) &&
975-
(IsLoadChain ? TTI.isLegalMaskedLoad(
976-
FixedVectorType::get(VecElemTy, NewNumVecElems),
977-
Alignment, AS, TTI::MaskKind::ConstantMask)
978-
: TTI.isLegalMaskedStore(
979-
FixedVectorType::get(VecElemTy, NewNumVecElems),
980-
Alignment, AS, TTI::MaskKind::ConstantMask))) {
950+
if (!accessIsAllowedAndFast(SizeBytes, AS, Alignment, VecElemBits)) {
951+
// If we have a non-power-of-2 element count, attempt to extend the
952+
// chain to the next power-of-2 if it makes the access allowed and
953+
// fast.
954+
bool AllowedAndFast = false;
955+
if (NumVecElems < TargetVF && !isPowerOf2_32(NumVecElems) &&
956+
VecElemBits >= 8) {
957+
// TargetVF may be a lot higher than NumVecElems,
958+
// so only extend to the next power of 2.
959+
assert(VecElemBits % 8 == 0);
960+
unsigned VecElemBytes = VecElemBits / 8;
961+
unsigned NewNumVecElems = PowerOf2Ceil(NumVecElems);
962+
unsigned NewSizeBytes = VecElemBytes * NewNumVecElems;
963+
964+
assert(isPowerOf2_32(TargetVF) &&
965+
"TargetVF expected to be a power of 2");
966+
assert(NewNumVecElems <= TargetVF &&
967+
"Should not extend past TargetVF");
968+
981969
LLVM_DEBUG(dbgs()
982-
<< "LSV: extending " << (IsLoadChain ? "load" : "store")
983-
<< " chain of " << NumVecElems << " "
984-
<< (IsLoadChain ? "loads" : "stores")
985-
<< " with total byte size of " << SizeBytes << " to "
986-
<< NewNumVecElems << " "
987-
<< (IsLoadChain ? "loads" : "stores")
988-
<< " with total byte size of " << NewSizeBytes
989-
<< ", TargetVF=" << TargetVF << " \n");
990-
991-
// Create (NewNumVecElems - NumVecElems) extra elements.
992-
// We are basing each extra element on CBegin, which means the offsets
993-
// should be based on SizeBytes, which represents the
994-
// offset from CBegin to the current end of the chain.
995-
unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
996-
for (unsigned I = 0; I < (NewNumVecElems - NumVecElems); I++) {
997-
ChainElem NewElem = createExtraElementAfter(
998-
C[CBegin], VecElemTy,
999-
APInt(ASPtrBits, SizeBytes + I * VecElemBytes), "Extend");
1000-
ExtendingLoadsStores.push_back(NewElem);
970+
<< "LSV: attempting to extend chain of " << NumVecElems
971+
<< " " << (IsLoadChain ? "loads" : "stores") << " to "
972+
<< NewNumVecElems << " elements\n");
973+
bool IsLegalToExtend =
974+
IsLoadChain ? TTI.isLegalMaskedLoad(
975+
FixedVectorType::get(VecElemTy, NewNumVecElems),
976+
Alignment, AS, TTI::MaskKind::ConstantMask)
977+
: TTI.isLegalMaskedStore(
978+
FixedVectorType::get(VecElemTy, NewNumVecElems),
979+
Alignment, AS, TTI::MaskKind::ConstantMask);
980+
// Only artificially increase the chain if it would be AllowedAndFast
981+
// and if the resulting masked load/store will be legal for the
982+
// target.
983+
if (IsLegalToExtend &&
984+
accessIsAllowedAndFast(NewSizeBytes, AS, Alignment,
985+
VecElemBits)) {
986+
LLVM_DEBUG(dbgs()
987+
<< "LSV: extending " << (IsLoadChain ? "load" : "store")
988+
<< " chain of " << NumVecElems << " "
989+
<< (IsLoadChain ? "loads" : "stores")
990+
<< " with total byte size of " << SizeBytes << " to "
991+
<< NewNumVecElems << " "
992+
<< (IsLoadChain ? "loads" : "stores")
993+
<< " with total byte size of " << NewSizeBytes
994+
<< ", TargetVF=" << TargetVF << " \n");
995+
996+
// Create (NewNumVecElems - NumVecElems) extra elements.
997+
// We are basing each extra element on CBegin, which means the
998+
// offsets should be based on SizeBytes, which represents the offset
999+
// from CBegin to the current end of the chain.
1000+
unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
1001+
for (unsigned I = 0; I < (NewNumVecElems - NumVecElems); I++) {
1002+
ChainElem NewElem = createExtraElementAfter(
1003+
C[CBegin], VecElemTy,
1004+
APInt(ASPtrBits, SizeBytes + I * VecElemBytes), "Extend");
1005+
ExtendingLoadsStores.push_back(NewElem);
1006+
}
1007+
1008+
// Update the size and number of elements for upcoming checks.
1009+
SizeBytes = NewSizeBytes;
1010+
NumVecElems = NewNumVecElems;
1011+
AllowedAndFast = true;
10011012
}
1002-
1003-
// Update the size and number of elements for upcoming checks.
1004-
SizeBytes = NewSizeBytes;
1005-
NumVecElems = NewNumVecElems;
10061013
}
1007-
}
1008-
1009-
if (!accessIsAllowedAndFast(SizeBytes, AS, Alignment, VecElemBits)) {
1010-
LLVM_DEBUG(
1011-
dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
1012-
"because its alignment is not AllowedAndFast: "
1013-
<< Alignment.value() << "\n");
1014-
continue;
1014+
if (!AllowedAndFast) {
1015+
// We were not able to achieve legality by extending the chain.
1016+
LLVM_DEBUG(dbgs()
1017+
<< "LSV: splitChainByAlignment discarding candidate chain "
1018+
"because its alignment is not AllowedAndFast: "
1019+
<< Alignment.value() << "\n");
1020+
continue;
1021+
}
10151022
}
10161023

10171024
if ((IsLoadChain &&

0 commit comments

Comments
 (0)