Skip to content

Commit 5abd368

Browse files
committed
LoopVectorize: vectorize decreasing integer IV in select-cmp
Extend the idea in llvm#67812 to support vectorizion of decreasing IV in select-cmp patterns. llvm#67812 enabled vectorization of the following example: long src[20000] = {4, 5, 2}; long r = 331; for (long i = 0; i < 20000; i++) { if (src[i] > 3) r = i; } return r; This patch extends the above idea to also vectorize: long src[20000] = {4, 5, 2}; long r = 331; for (long i = 20000 - 1; i >= 0; i--) { if (src[i] > 3) r = i; } return r;
1 parent ba97834 commit 5abd368

File tree

8 files changed

+333
-93
lines changed

8 files changed

+333
-93
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,18 @@ enum class RecurKind {
5454
///< loop invariant, and both x and y are integer type.
5555
FAnyOf, ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
5656
///< loop invariant, and both x and y are integer type.
57-
IFindLastIV, ///< FindLast reduction with select(icmp(),x,y) where one of
58-
///< (x,y) is increasing loop induction PHI, and both x and y are
59-
///< integer type.
60-
FFindLastIV ///< FindLast reduction with select(fcmp(),x,y) where one of (x,y)
61-
///< is increasing loop induction PHI, and both x and y are
62-
///< integer type.
57+
IFindLastIncIV, ///< FindLast reduction with select(icmp(),x,y) where one of
58+
///< (x,y) is increasing loop induction PHI, and both x and y
59+
///< are integer type.
60+
FFindLastIncIV, ///< FindLast reduction with select(fcmp(),x,y) where one of
61+
///< (x,y) is increasing loop induction PHI, and both x and y
62+
///< are integer type.
63+
IFindLastDecIV, ///< FindLast reduction with select(icmp(),x,y) where one of
64+
///< (x,y) is decreasing loop induction PHI, and both x and y
65+
///< are integer type.
66+
FFindLastDecIV ///< FindLast reduction with select(fcmp(),x,y) where one of
67+
///< (x,y) is decreasing loop induction PHI, and both x and y
68+
///< are integer type.
6369
// TODO: Any_of and FindLast reduction need not be restricted to integer type
6470
// only.
6571
};
@@ -163,10 +169,8 @@ class RecurrenceDescriptor {
163169
/// Returns a struct describing whether the instruction is either a
164170
/// Select(ICmp(A, B), X, Y), or
165171
/// Select(FCmp(A, B), X, Y)
166-
/// where one of (X, Y) is an increasing loop induction variable, and the
167-
/// other is a PHI value.
168-
// TODO: FindLast does not need be restricted to increasing loop induction
169-
// variables.
172+
/// where one of (X, Y) is an increasing/decreasing loop induction variable,
173+
/// and the other is a PHI value.
170174
static InstDesc isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
171175
Instruction *I, ScalarEvolution *SE);
172176

@@ -259,9 +263,13 @@ class RecurrenceDescriptor {
259263
}
260264

261265
/// Returns true if the recurrence kind is of the form
262-
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
266+
/// select(cmp(),x,y) where one of (x,y) is increasing/decreasing loop
267+
/// induction.
263268
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
264-
return Kind == RecurKind::IFindLastIV || Kind == RecurKind::FFindLastIV;
269+
return Kind == RecurKind::IFindLastIncIV ||
270+
Kind == RecurKind::FFindLastIncIV ||
271+
Kind == RecurKind::IFindLastDecIV ||
272+
Kind == RecurKind::FFindLastDecIV;
265273
}
266274

267275
/// Returns the type of the recurrence. This type can be narrower than the

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,10 @@ Value *createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, RecurKind RK,
374374

375375
/// See RecurrenceDescriptor::isFindLastIVPattern for a description of the
376376
/// pattern we are trying to match. In this pattern, since the selected set of
377-
/// values forms an increasing sequence, we are selecting the maximum value from
378-
/// \p Left and \p Right.
379-
Value *createFindLastIVOp(IRBuilderBase &Builder, Value *Left, Value *Right);
377+
/// values forms an increasing/decreasing sequence, we are selecting the
378+
/// maximum/minimum value from \p Left and \p Right.
379+
Value *createFindLastIVOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
380+
Value *Right);
380381

381382
/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
382383
/// The Builder's fast-math-flags must be set to propagate the expected values.
@@ -409,7 +410,8 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
409410
PHINode *OrigPhi);
410411

411412
/// Create a target reduction of the given vector \p Src for a reduction of the
412-
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
413+
/// kinds RecurKind::IFindLastIncIV, RecurKind::FFindLastIncIV,
414+
/// RecurKind::IFindLastDecIV, and RecurKind::FFindLastDecIV. The reduction
413415
/// operation is described by \p Desc.
414416
Value *createFindLastIVTargetReduction(IRBuilderBase &B, Value *Src,
415417
const RecurrenceDescriptor &Desc);

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5454
case RecurKind::UMin:
5555
case RecurKind::IAnyOf:
5656
case RecurKind::FAnyOf:
57-
case RecurKind::IFindLastIV:
58-
case RecurKind::FFindLastIV:
57+
case RecurKind::IFindLastIncIV:
58+
case RecurKind::FFindLastIncIV:
59+
case RecurKind::IFindLastDecIV:
60+
case RecurKind::FFindLastDecIV:
5961
return true;
6062
}
6163
return false;
@@ -664,6 +666,8 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
664666
: RecurKind::FAnyOf);
665667
}
666668

669+
enum class LoopInductionDirection { None, Increasing, Decreasing };
670+
667671
// We are looking for loops that do something like this:
668672
// int r = 0;
669673
// for (int i = 0; i < n; i++) {
@@ -711,47 +715,65 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
711715
else
712716
return InstDesc(false, I);
713717

714-
auto IsIncreasingLoopInduction = [&SE, &Loop](Value *V) {
715-
if (!SE)
716-
return false;
717-
718+
auto GetLoopInduction = [&SE, &Loop](Value *V) {
718719
Type *Ty = V->getType();
719-
if (!SE->isSCEVable(Ty))
720-
return false;
720+
if (!SE || !SE->isSCEVable(Ty))
721+
return LoopInductionDirection::None;
721722

722723
auto *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(V));
723724
if (!AR)
724-
return false;
725-
726-
const SCEV *Step = AR->getStepRecurrence(*SE);
727-
if (!SE->isKnownPositive(Step))
728-
return false;
725+
return LoopInductionDirection::None;
729726

730727
const ConstantRange IVRange = SE->getSignedRange(AR);
731728
unsigned NumBits = Ty->getIntegerBitWidth();
732-
// Keep the minmum value of the recurrence type as the sentinel value.
733-
// The maximum acceptable range for the increasing induction variable,
734-
// called the valid range, will be defined as
735-
// [<sentinel value> + 1, SignedMin(<recurrence type>))
736-
// TODO: This range restriction can be lifted by adding an additional
737-
// virtual OR reduction.
738-
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
739-
const ConstantRange ValidRange = ConstantRange::getNonEmpty(
740-
Sentinel + 1, APInt::getSignedMinValue(NumBits));
741-
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
742-
<< ", and the signed range of " << *AR << " is "
743-
<< IVRange << "\n");
744-
return ValidRange.contains(IVRange);
729+
const SCEV *Step = AR->getStepRecurrence(*SE);
730+
731+
if (SE->isKnownPositive(Step)) {
732+
// For increasing IV, keep the minimum value of the recurrence type as the
733+
// sentinel value. The maximum acceptable range will be defined as
734+
// [<sentinel value> + 1, <sentinel value>)
735+
// TODO: This range restriction can be lifted by adding an additional
736+
// virtual OR reduction.
737+
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
738+
const ConstantRange ValidRange =
739+
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
740+
LLVM_DEBUG(dbgs() << "LV: FindLastIncIV valid range is " << ValidRange
741+
<< ", and the signed range of " << *AR << " is "
742+
<< IVRange << "\n");
743+
if (ValidRange.contains(IVRange))
744+
return LoopInductionDirection::Increasing;
745+
} else if (SE->isKnownNegative(Step)) {
746+
// For decreasing IV, keep the maximum value of the recurrence type as the
747+
// sentinel value. The maximum acceptable range will be defined as
748+
// [<sentinel value> + 1, <sentinel value>)
749+
const APInt Sentinel = APInt::getSignedMaxValue(NumBits);
750+
const ConstantRange ValidRange =
751+
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
752+
LLVM_DEBUG(dbgs() << "LV: FindLastDecIV valid range is " << ValidRange
753+
<< ", and the signed range of " << *AR << " is "
754+
<< IVRange << "\n");
755+
if (ValidRange.contains(IVRange))
756+
return LoopInductionDirection::Decreasing;
757+
}
758+
return LoopInductionDirection::None;
745759
};
746760

747761
// We are looking for selects of the form:
748762
// select(cmp(), phi, loop_induction) or
749763
// select(cmp(), loop_induction, phi)
750-
if (!IsIncreasingLoopInduction(NonRdxPhi))
751-
return InstDesc(false, I);
752-
753-
return InstDesc(I, isa<ICmpInst>(I->getOperand(0)) ? RecurKind::IFindLastIV
754-
: RecurKind::FFindLastIV);
764+
switch (GetLoopInduction(NonRdxPhi)) {
765+
case LoopInductionDirection::None:
766+
break;
767+
case LoopInductionDirection::Increasing:
768+
return InstDesc(I, isa<ICmpInst>(I->getOperand(0))
769+
? RecurKind::IFindLastIncIV
770+
: RecurKind::FFindLastIncIV);
771+
case LoopInductionDirection::Decreasing:
772+
return InstDesc(I, isa<ICmpInst>(I->getOperand(0))
773+
? RecurKind::IFindLastDecIV
774+
: RecurKind::FFindLastDecIV);
775+
}
776+
return InstDesc(false, I);
755777
}
756778

757779
RecurrenceDescriptor::InstDesc
@@ -995,8 +1017,8 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
9951017
<< *Phi << "\n");
9961018
return true;
9971019
}
998-
if (AddReductionVar(Phi, RecurKind::IFindLastIV, TheLoop, FMF, RedDes, DB, AC,
999-
DT, SE)) {
1020+
if (AddReductionVar(Phi, RecurKind::IFindLastIncIV, TheLoop, FMF, RedDes, DB,
1021+
AC, DT, SE)) {
10001022
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
10011023
return true;
10021024
}
@@ -1189,9 +1211,12 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
11891211
case RecurKind::FAnyOf:
11901212
return getRecurrenceStartValue();
11911213
break;
1192-
case RecurKind::IFindLastIV:
1193-
case RecurKind::FFindLastIV:
1214+
case RecurKind::IFindLastIncIV:
1215+
case RecurKind::FFindLastIncIV:
11941216
return getRecurrenceIdentity(RecurKind::SMax, Tp, FMF);
1217+
case RecurKind::IFindLastDecIV:
1218+
case RecurKind::FFindLastDecIV:
1219+
return getRecurrenceIdentity(RecurKind::SMin, Tp, FMF);
11951220
default:
11961221
llvm_unreachable("Unknown recurrence kind");
11971222
}
@@ -1219,14 +1244,16 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12191244
case RecurKind::UMax:
12201245
case RecurKind::UMin:
12211246
case RecurKind::IAnyOf:
1222-
case RecurKind::IFindLastIV:
1247+
case RecurKind::IFindLastIncIV:
1248+
case RecurKind::IFindLastDecIV:
12231249
return Instruction::ICmp;
12241250
case RecurKind::FMax:
12251251
case RecurKind::FMin:
12261252
case RecurKind::FMaximum:
12271253
case RecurKind::FMinimum:
12281254
case RecurKind::FAnyOf:
1229-
case RecurKind::FFindLastIV:
1255+
case RecurKind::FFindLastIncIV:
1256+
case RecurKind::FFindLastDecIV:
12301257
return Instruction::FCmp;
12311258
default:
12321259
llvm_unreachable("Unknown recurrence operation");

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -942,9 +942,18 @@ Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
942942
return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
943943
}
944944

945-
Value *llvm::createFindLastIVOp(IRBuilderBase &Builder, Value *Left,
946-
Value *Right) {
947-
return createMinMaxOp(Builder, RecurKind::SMax, Left, Right);
945+
Value *llvm::createFindLastIVOp(IRBuilderBase &Builder, RecurKind RK,
946+
Value *Left, Value *Right) {
947+
switch (RK) {
948+
default:
949+
llvm_unreachable("Unexpected reduction kind");
950+
case RecurKind::IFindLastIncIV:
951+
case RecurKind::FFindLastIncIV:
952+
return createMinMaxOp(Builder, RecurKind::SMax, Left, Right);
953+
case RecurKind::IFindLastDecIV:
954+
case RecurKind::FFindLastDecIV:
955+
return createMinMaxOp(Builder, RecurKind::SMin, Left, Right);
956+
}
948957
}
949958

950959
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
@@ -1069,10 +1078,16 @@ Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
10691078

10701079
Value *llvm::createFindLastIVTargetReduction(IRBuilderBase &Builder, Value *Src,
10711080
const RecurrenceDescriptor &Desc) {
1072-
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
1073-
Desc.getRecurrenceKind()) &&
1074-
"Unexpected reduction kind");
1075-
return Builder.CreateIntMaxReduce(Src, true);
1081+
switch (Desc.getRecurrenceKind()) {
1082+
default:
1083+
llvm_unreachable("Unexpected reduction kind");
1084+
case RecurKind::IFindLastIncIV:
1085+
case RecurKind::FFindLastIncIV:
1086+
return Builder.CreateIntMaxReduce(Src, true);
1087+
case RecurKind::IFindLastDecIV:
1088+
case RecurKind::FFindLastDecIV:
1089+
return Builder.CreateIntMinReduce(Src, true);
1090+
}
10761091
}
10771092

10781093
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3902,7 +3902,8 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
39023902
ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK,
39033903
ReducedPartRdx, RdxPart);
39043904
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
3905-
ReducedPartRdx = createFindLastIVOp(Builder, ReducedPartRdx, RdxPart);
3905+
ReducedPartRdx =
3906+
createFindLastIVOp(Builder, RK, ReducedPartRdx, RdxPart);
39063907
else
39073908
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
39083909
}

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14352,8 +14352,10 @@ class HorizontalReduction {
1435214352
case RecurKind::FMulAdd:
1435314353
case RecurKind::IAnyOf:
1435414354
case RecurKind::FAnyOf:
14355-
case RecurKind::IFindLastIV:
14356-
case RecurKind::FFindLastIV:
14355+
case RecurKind::IFindLastIncIV:
14356+
case RecurKind::FFindLastIncIV:
14357+
case RecurKind::IFindLastDecIV:
14358+
case RecurKind::FFindLastDecIV:
1435714359
case RecurKind::None:
1435814360
llvm_unreachable("Unexpected reduction kind for repeated scalar.");
1435914361
}
@@ -14443,8 +14445,10 @@ class HorizontalReduction {
1444314445
case RecurKind::FMulAdd:
1444414446
case RecurKind::IAnyOf:
1444514447
case RecurKind::FAnyOf:
14446-
case RecurKind::IFindLastIV:
14447-
case RecurKind::FFindLastIV:
14448+
case RecurKind::IFindLastIncIV:
14449+
case RecurKind::FFindLastIncIV:
14450+
case RecurKind::IFindLastDecIV:
14451+
case RecurKind::FFindLastDecIV:
1444814452
case RecurKind::None:
1444914453
llvm_unreachable("Unexpected reduction kind for reused scalars.");
1445014454
}

llvm/test/Transforms/LoopVectorize/if-reduction.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -912,9 +912,9 @@ for.end: ; preds = %for.body, %entry
912912

913913
@table = constant [13 x i16] [i16 10, i16 35, i16 69, i16 147, i16 280, i16 472, i16 682, i16 1013, i16 1559, i16 2544, i16 4553, i16 6494, i16 10000], align 1
914914

915-
; CHECK-LABEL: @non_reduction_index(
916-
; CHECK-NOT: <4 x i16>
917-
define i16 @non_reduction_index(i16 noundef %val) {
915+
; CHECK-LABEL: @reduction_index(
916+
; CHECK: <4 x i16>
917+
define i16 @reduction_index(i16 noundef %val) {
918918
entry:
919919
br label %for.body
920920

@@ -936,9 +936,9 @@ for.body: ; preds = %entry, %for.body
936936

937937
@tablef = constant [13 x half] [half 10.0, half 35.0, half 69.0, half 147.0, half 280.0, half 472.0, half 682.0, half 1013.0, half 1559.0, half 2544.0, half 4556.0, half 6496.0, half 10000.0], align 1
938938

939-
; CHECK-LABEL: @non_reduction_index_half(
940-
; CHECK-NOT: <4 x half>
941-
define i16 @non_reduction_index_half(half noundef %val) {
939+
; CHECK-LABEL: @reduction_index_half(
940+
; CHECK: <4 x half>
941+
define i16 @reduction_index_half(half noundef %val) {
942942
entry:
943943
br label %for.body
944944

0 commit comments

Comments
 (0)