Skip to content

Commit 4113e63

Browse files
committed
Address review feedback
1 parent cae6020 commit 4113e63

File tree

2 files changed

+70
-28
lines changed

2 files changed

+70
-28
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,15 @@ class Vectorizer {
283283
bool runOnChain(Chain &C);
284284

285285
/// Splits the chain into subchains of instructions which read/write a
286-
/// contiguous block of memory. Discards any length-1 subchains (because
287-
/// there's nothing to vectorize in there).
286+
/// contiguous block of memory. Discards any length-1 subchains (because
287+
/// there's nothing to vectorize in there). Also attempts to fill gaps with
288+
/// "extra" elements to artificially make chains contiguous in some cases.
288289
std::vector<Chain> splitChainByContiguity(Chain &C);
289290

290291
/// Splits the chain into subchains where it's safe to hoist loads up to the
291292
/// beginning of the sub-chain and it's safe to sink loads up to the end of
292-
/// the sub-chain. Discards any length-1 subchains.
293+
/// the sub-chain. Discards any length-1 subchains. Also attempts to extend
294+
/// non-power-of-two chains by adding "extra" elements in some cases.
293295
std::vector<Chain> splitChainByMayAliasInstrs(Chain &C);
294296

295297
/// Splits the chain into subchains that make legal, aligned accesses.
@@ -730,14 +732,15 @@ std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) {
730732
// which could cancel out the benefits of reducing number of load/stores.
731733
if (TryFillGaps &&
732734
SzBits == DL.getTypeSizeInBits(getLoadStoreType(It->Inst))) {
733-
APInt OffsetOfGapStart = Prev.OffsetFromLeader + PrevSzBytes;
734-
APInt GapSzBytes = It->OffsetFromLeader - OffsetOfGapStart;
735+
APInt OffsetFromLeaderOfGapStart = Prev.OffsetFromLeader + PrevSzBytes;
736+
APInt GapSzBytes = It->OffsetFromLeader - OffsetFromLeaderOfGapStart;
735737
if (GapSzBytes == PrevSzBytes) {
736738
// There is a single gap between Prev and Curr; create one extra element.
737739
ChainElem NewElem = createExtraElementAfter(
738740
Prev, PrevSzBytes, "GapFill",
739-
commonAlignment(LeaderOfChainAlign,
740-
OffsetOfGapStart.abs().getLimitedValue()));
741+
commonAlignment(
742+
LeaderOfChainAlign,
743+
OffsetFromLeaderOfGapStart.abs().getLimitedValue()));
741744
CurChain.push_back(NewElem);
742745
CurChain.push_back(*It);
743746
continue;
@@ -748,13 +751,15 @@ std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) {
748751
if ((GapSzBytes == 2 * PrevSzBytes) && (CurChain.size() % 4 == 1)) {
749752
ChainElem NewElem1 = createExtraElementAfter(
750753
Prev, PrevSzBytes, "GapFill",
751-
commonAlignment(LeaderOfChainAlign,
752-
OffsetOfGapStart.abs().getLimitedValue()));
753-
ChainElem NewElem2 = createExtraElementAfter(
754-
NewElem1, PrevSzBytes, "GapFill",
755754
commonAlignment(
756755
LeaderOfChainAlign,
757-
(OffsetOfGapStart + PrevSzBytes).abs().getLimitedValue()));
756+
OffsetFromLeaderOfGapStart.abs().getLimitedValue()));
757+
ChainElem NewElem2 = createExtraElementAfter(
758+
NewElem1, PrevSzBytes, "GapFill",
759+
commonAlignment(LeaderOfChainAlign,
760+
(OffsetFromLeaderOfGapStart + PrevSzBytes)
761+
.abs()
762+
.getLimitedValue()));
758763
CurChain.push_back(NewElem1);
759764
CurChain.push_back(NewElem2);
760765
CurChain.push_back(*It);
@@ -920,9 +925,14 @@ std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
920925
}
921926
}
922927

923-
// Attempt to extend non-power-of-2 chains to the next power of 2.
928+
// The vectorizer does not support non-power-of-2 element count vectors.
929+
// Extend the chain to the next power-of-2 if the current chain:
930+
// 1. Does not have a power-of-2 element count
931+
// 2. Would be legal to vectorize if the element count was extended to
932+
// the next power-of-2
924933
Chain ExtendingLoadsStores;
925-
if (NumVecElems < TargetVF && NumVecElems % 2 != 0 && VecElemBits >= 8) {
934+
if (NumVecElems < TargetVF && !isPowerOf2_32(NumVecElems) &&
935+
VecElemBits >= 8 && isPowerOf2_32(TargetVF)) {
926936
// TargetVF may be a lot higher than NumVecElems,
927937
// so only extend to the next power of 2.
928938
assert(VecElemBits % 8 == 0);
@@ -936,10 +946,8 @@ std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
936946
<< NumVecElems << " "
937947
<< (IsLoadChain ? "loads" : "stores") << " to "
938948
<< NewNumVecElems << " elements\n");
939-
// Do not artificially increase the chain if it becomes misaligned or if
940-
// the associated masked load/store is not legal, otherwise we may
941-
// unnecessarily split the chain when the target actually supports
942-
// non-pow2 VF.
949+
// Only artificially increase the chain if it would be AllowedAndFast
950+
// and if the resulting masked load/store will be legal for the target.
943951
if (accessIsAllowedAndFast(NewSizeBytes, AS, Alignment, VecElemBits) &&
944952
(IsLoadChain ? TTI.isLegalMaskedLoad(
945953
FixedVectorType::get(VecElemTy, NewNumVecElems),
@@ -1039,15 +1047,14 @@ bool Vectorizer::vectorizeChain(Chain &C) {
10391047
if (C.size() < 2)
10401048
return false;
10411049

1050+
bool ChainContainsExtraLoadsStores = llvm::any_of(
1051+
C, [this](const ChainElem &E) { return ExtraElements.contains(E.Inst); });
1052+
10421053
// If we are left with a two-element chain, and one of the elements is an
10431054
// extra element, we don't want to vectorize.
1044-
if (C.size() == 2 &&
1045-
(ExtraElements.contains(C[0].Inst) || ExtraElements.contains(C[1].Inst)))
1055+
if (C.size() == 2 && ChainContainsExtraLoadsStores)
10461056
return false;
10471057

1048-
bool ChainContainsExtraLoadsStores = llvm::any_of(
1049-
C, [this](const ChainElem &E) { return ExtraElements.contains(E.Inst); });
1050-
10511058
sortChainInOffsetOrder(C);
10521059

10531060
LLVM_DEBUG({
@@ -1847,8 +1854,11 @@ std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB,
18471854
bool Vectorizer::accessIsAllowedAndFast(unsigned SizeBytes, unsigned AS,
18481855
Align Alignment,
18491856
unsigned VecElemBits) const {
1857+
// Aligned vector accesses are ALWAYS faster than element-wise accesses.
18501858
if (Alignment.value() % SizeBytes == 0)
18511859
return true;
1860+
1861+
// Element-wise access *might* be faster than misaligned vector accesses.
18521862
unsigned VectorizedSpeed = 0;
18531863
bool AllowsMisaligned = TTI.allowsMisalignedMemoryAccesses(
18541864
F.getContext(), SizeBytes * 8, AS, Alignment, &VectorizedSpeed);

llvm/test/Transforms/LoadStoreVectorizer/NVPTX/extend-chain.ll

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
;; code. Alignment and other requirement for vectorization should
77
;; still be met.
88

9-
define void @load3to4(ptr %p) #0 {
9+
define void @load3to4(ptr %p) {
1010
; CHECK-LABEL: define void @load3to4(
1111
; CHECK-SAME: ptr [[P:%.*]]) {
1212
; CHECK-NEXT: [[P_0:%.*]] = getelementptr i32, ptr [[P]], i32 0
@@ -28,7 +28,7 @@ define void @load3to4(ptr %p) #0 {
2828
ret void
2929
}
3030

31-
define void @load5to8(ptr %p) #0 {
31+
define void @load5to8(ptr %p) {
3232
; CHECK-LABEL: define void @load5to8(
3333
; CHECK-SAME: ptr [[P:%.*]]) {
3434
; CHECK-NEXT: [[P_0:%.*]] = getelementptr i16, ptr [[P]], i32 0
@@ -52,13 +52,45 @@ define void @load5to8(ptr %p) #0 {
5252
%v0 = load i16, ptr %p.0, align 16
5353
%v1 = load i16, ptr %p.1, align 2
5454
%v2 = load i16, ptr %p.2, align 4
55-
%v3 = load i16, ptr %p.3, align 8
56-
%v4 = load i16, ptr %p.4, align 2
55+
%v3 = load i16, ptr %p.3, align 2
56+
%v4 = load i16, ptr %p.4, align 8
5757

5858
ret void
5959
}
6060

61-
define void @load3to4_unaligned(ptr %p) #0 {
61+
define void @load6to8(ptr %p) {
62+
; CHECK-LABEL: define void @load6to8(
63+
; CHECK-SAME: ptr [[P:%.*]]) {
64+
; CHECK-NEXT: [[P_0:%.*]] = getelementptr i16, ptr [[P]], i32 0
65+
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr align 16 [[P_0]], <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 false, i1 false>, <8 x i16> poison)
66+
; CHECK-NEXT: [[V05:%.*]] = extractelement <8 x i16> [[TMP1]], i32 0
67+
; CHECK-NEXT: [[V16:%.*]] = extractelement <8 x i16> [[TMP1]], i32 1
68+
; CHECK-NEXT: [[V27:%.*]] = extractelement <8 x i16> [[TMP1]], i32 2
69+
; CHECK-NEXT: [[V38:%.*]] = extractelement <8 x i16> [[TMP1]], i32 3
70+
; CHECK-NEXT: [[V49:%.*]] = extractelement <8 x i16> [[TMP1]], i32 4
71+
; CHECK-NEXT: [[EXTEND10:%.*]] = extractelement <8 x i16> [[TMP1]], i32 5
72+
; CHECK-NEXT: [[EXTEND211:%.*]] = extractelement <8 x i16> [[TMP1]], i32 6
73+
; CHECK-NEXT: [[EXTEND412:%.*]] = extractelement <8 x i16> [[TMP1]], i32 7
74+
; CHECK-NEXT: ret void
75+
;
76+
%p.0 = getelementptr i16, ptr %p, i32 0
77+
%p.1 = getelementptr i16, ptr %p, i32 1
78+
%p.2 = getelementptr i16, ptr %p, i32 2
79+
%p.3 = getelementptr i16, ptr %p, i32 3
80+
%p.4 = getelementptr i16, ptr %p, i32 4
81+
%p.5 = getelementptr i16, ptr %p, i32 5
82+
83+
%v0 = load i16, ptr %p.0, align 16
84+
%v1 = load i16, ptr %p.1, align 2
85+
%v2 = load i16, ptr %p.2, align 4
86+
%v3 = load i16, ptr %p.3, align 2
87+
%v4 = load i16, ptr %p.4, align 8
88+
%v5 = load i16, ptr %p.5, align 2
89+
90+
ret void
91+
}
92+
93+
define void @load3to4_unaligned(ptr %p) {
6294
; CHECK-LABEL: define void @load3to4_unaligned(
6395
; CHECK-SAME: ptr [[P:%.*]]) {
6496
; CHECK-NEXT: [[P_0:%.*]] = getelementptr i32, ptr [[P]], i32 0

0 commit comments

Comments
 (0)