Skip to content

Commit 7e8bf11

Browse files
committed
[LoopVectorize] Vectorize select-cmp reduction pattern for increasing
integer induction variable. Consider the following loop: `int rdx = init; for (int i = 0; i < n; ++i) rdx = (a[i] > b[i]) ? i : rdx;` We can vectorize this loop if `i` is an increasing induction variable. The final reduced value will be the maximum value of `i` for which the condition `a[i] > b[i]` is satisfied, or the start value `init` if the condition is never satisfied. This patch adds two new RecurKind enums: IFindLastIV and FFindLastIV.
1 parent ab9cd27 commit 7e8bf11

File tree

10 files changed

+2068
-40
lines changed

10 files changed

+2068
-40
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,16 @@ enum class RecurKind {
5252
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5353
IAnyOf, ///< Any_of reduction with select(icmp(),x,y) where one of (x,y) is
5454
///< loop invariant, and both x and y are integer type.
55-
FAnyOf ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
55+
FAnyOf, ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
5656
///< loop invariant, and both x and y are integer type.
57-
// TODO: Any_of reduction need not be restricted to integer type only.
57+
IFindLastIV, ///< FindLast reduction with select(icmp(),x,y) where one of
58+
///< (x,y) is an increasing loop induction PHI, and both x and y are
59+
///< integer type.
60+
FFindLastIV ///< FindLast reduction with select(fcmp(),x,y) where one of (x,y)
61+
///< is an increasing loop induction PHI, and both x and y are
62+
///< integer type.
63+
// TODO: Any_of and FindLast reductions need not be restricted to integer type
64+
// only.
5865
};
5966

6067
/// The RecurrenceDescriptor is used to identify recurrences variables in a
@@ -126,7 +133,7 @@ class RecurrenceDescriptor {
126133
/// the returned struct.
127134
static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
128135
RecurKind Kind, InstDesc &Prev,
129-
FastMathFlags FuncFMF);
136+
FastMathFlags FuncFMF, ScalarEvolution *SE);
130137

131138
/// Returns true if instruction I has multiple uses in Insts
132139
static bool hasMultipleUsesOf(Instruction *I,
@@ -153,6 +160,16 @@ class RecurrenceDescriptor {
153160
static InstDesc isAnyOfPattern(Loop *Loop, PHINode *OrigPhi, Instruction *I,
154161
InstDesc &Prev);
155162

163+
/// Returns a struct describing whether the instruction is either a
164+
/// Select(ICmp(A, B), X, Y), or
165+
/// Select(FCmp(A, B), X, Y)
166+
/// where one of (X, Y) is an increasing loop induction variable, and the
167+
/// other is a PHI value.
168+
// TODO: FindLast does not need to be restricted to increasing loop induction
169+
// variables.
170+
static InstDesc isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
171+
Instruction *I, ScalarEvolution *SE);
172+
156173
/// Returns a struct describing if the instruction is a
157174
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
158175
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
@@ -241,6 +258,12 @@ class RecurrenceDescriptor {
241258
return Kind == RecurKind::IAnyOf || Kind == RecurKind::FAnyOf;
242259
}
243260

261+
/// Returns true if the recurrence kind is of the form
262+
/// select(cmp(),x,y) where one of (x,y) is an increasing loop induction variable.
263+
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
264+
return Kind == RecurKind::IFindLastIV || Kind == RecurKind::FFindLastIV;
265+
}
266+
244267
/// Returns the type of the recurrence. This type can be narrower than the
245268
/// actual type of the Phi if the recurrence has been type-promoted.
246269
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,12 @@ CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
372372
Value *createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, RecurKind RK,
373373
Value *Left, Value *Right);
374374

375+
/// See RecurrenceDescriptor::isFindLastIVPattern for a description of the
376+
/// pattern we are trying to match. In this pattern, since the selected set of
377+
/// values forms an increasing sequence, we are selecting the maximum value from
378+
/// \p Left and \p Right.
379+
Value *createFindLastIVOp(IRBuilderBase &Builder, Value *Left, Value *Right);
380+
375381
/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
376382
/// The Builder's fast-math-flags must be set to propagate the expected values.
377383
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
@@ -402,6 +408,12 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
402408
const RecurrenceDescriptor &Desc,
403409
PHINode *OrigPhi);
404410

411+
/// Create a target reduction of the given vector \p Src for a reduction of the
412+
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
413+
/// operation is described by \p Desc.
414+
Value *createFindLastIVTargetReduction(IRBuilderBase &B, Value *Src,
415+
const RecurrenceDescriptor &Desc);
416+
405417
/// Create a generic target reduction using a recurrence descriptor \p Desc
406418
/// The target is queried to determine if intrinsics or shuffle sequences are
407419
/// required to implement the reduction.
@@ -415,6 +427,15 @@ Value *createOrderedReduction(IRBuilderBase &B,
415427
const RecurrenceDescriptor &Desc, Value *Src,
416428
Value *Start);
417429

430+
/// Returns a set of cmp and select instructions as shown below:
431+
/// Select(Cmp(NE, Rdx, Iden), Rdx, InitVal)
432+
/// where \p Rdx is a scalar value generated by target reduction, Iden is the
433+
/// sentinel value of the recurrence descriptor \p Desc, and InitVal is the
434+
/// start value of the recurrence descriptor \p Desc.
435+
Value *createSentinelValueHandling(IRBuilderBase &Builder,
436+
const RecurrenceDescriptor &Desc,
437+
Value *Rdx);
438+
418439
/// Get the intersection (logical and) of all of the potential IR flags
419440
/// of each scalar operation (VL) that will be converted into a vector (I).
420441
/// If OpValue is non-null, we only consider operations similar to OpValue

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5454
case RecurKind::UMin:
5555
case RecurKind::IAnyOf:
5656
case RecurKind::FAnyOf:
57+
case RecurKind::IFindLastIV:
58+
case RecurKind::FFindLastIV:
5759
return true;
5860
}
5961
return false;
@@ -375,7 +377,7 @@ bool RecurrenceDescriptor::AddReductionVar(
375377
// type-promoted).
376378
if (Cur != Start) {
377379
ReduxDesc =
378-
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
380+
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
379381
ExactFPMathInst = ExactFPMathInst == nullptr
380382
? ReduxDesc.getExactFPMathInst()
381383
: ExactFPMathInst;
@@ -662,6 +664,87 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
662664
: RecurKind::FAnyOf);
663665
}
664666

667+
// We are looking for loops that do something like this:
668+
// int r = 0;
669+
// for (int i = 0; i < n; i++) {
670+
// if (src[i] > 3)
671+
// r = i;
672+
// }
673+
// The reduction value (r) is derived from either the values of an increasing
674+
// induction variable (i) sequence, or from the start value (0).
675+
// The LLVM IR generated for such loops would be as follows:
676+
// for.body:
677+
// %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
678+
// %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
679+
// ...
680+
// %cmp = icmp sgt i32 %5, 3
681+
// %spec.select = select i1 %cmp, i32 %i, i32 %r
682+
// %inc = add nsw i32 %i, 1
683+
// ...
684+
// Since 'i' is an increasing induction variable, the reduction value after the
685+
// loop will be the maximum value of 'i' for which the condition (src[i] > 3) is
686+
// satisfied, or the start value (0 in the example above). When the start value
687+
// of the increasing induction variable 'i' is greater than the minimum value of
688+
// the data type, we can use the minimum value of the data type as a sentinel
689+
// value to replace the start value. This allows us to perform a single
690+
// reduction max operation to obtain the final reduction result.
691+
// TODO: It is possible to solve the case where the start value is the minimum
692+
// value of the data type or a non-constant value by using mask and multiple
693+
// reduction operations.
694+
RecurrenceDescriptor::InstDesc
695+
RecurrenceDescriptor::isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
696+
Instruction *I, ScalarEvolution *SE) {
697+
// Only match select with single use cmp condition.
698+
// TODO: Only handle single use for now.
699+
CmpInst::Predicate Pred;
700+
if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
701+
m_Value())))
702+
return InstDesc(false, I);
703+
704+
SelectInst *SI = cast<SelectInst>(I);
705+
Value *NonRdxPhi = nullptr;
706+
707+
if (OrigPhi == dyn_cast<PHINode>(SI->getTrueValue()))
708+
NonRdxPhi = SI->getFalseValue();
709+
else if (OrigPhi == dyn_cast<PHINode>(SI->getFalseValue()))
710+
NonRdxPhi = SI->getTrueValue();
711+
else
712+
return InstDesc(false, I);
713+
714+
auto IsIncreasingLoopInduction = [&SE, &Loop](Value *V) {
715+
auto *Phi = dyn_cast<PHINode>(V);
716+
if (!Phi)
717+
return false;
718+
719+
if (!SE)
720+
return false;
721+
722+
InductionDescriptor ID;
723+
if (!InductionDescriptor::isInductionPHI(Phi, Loop, SE, ID))
724+
return false;
725+
726+
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi));
727+
if (!AR->hasNoSignedWrap())
728+
return false;
729+
730+
ConstantInt *IVStartValue = dyn_cast<ConstantInt>(ID.getStartValue());
731+
if (!IVStartValue || IVStartValue->isMinSignedValue())
732+
return false;
733+
734+
const SCEV *Step = ID.getStep();
735+
return SE->isKnownPositive(Step);
736+
};
737+
738+
// We are looking for selects of the form:
739+
// select(cmp(), phi, loop_induction) or
740+
// select(cmp(), loop_induction, phi)
741+
if (!IsIncreasingLoopInduction(NonRdxPhi))
742+
return InstDesc(false, I);
743+
744+
return InstDesc(I, isa<ICmpInst>(I->getOperand(0)) ? RecurKind::IFindLastIV
745+
: RecurKind::FFindLastIV);
746+
}
747+
665748
RecurrenceDescriptor::InstDesc
666749
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
667750
const InstDesc &Prev) {
@@ -765,10 +848,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
765848
return InstDesc(true, SI);
766849
}
767850

768-
RecurrenceDescriptor::InstDesc
769-
RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
770-
Instruction *I, RecurKind Kind,
771-
InstDesc &Prev, FastMathFlags FuncFMF) {
851+
RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
852+
Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
853+
FastMathFlags FuncFMF, ScalarEvolution *SE) {
772854
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
773855
switch (I->getOpcode()) {
774856
default:
@@ -798,6 +880,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
798880
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
799881
Kind == RecurKind::Add || Kind == RecurKind::Mul)
800882
return isConditionalRdxPattern(Kind, I);
883+
if (isFindLastIVRecurrenceKind(Kind))
884+
return isFindLastIVPattern(L, OrigPhi, I, SE);
801885
[[fallthrough]];
802886
case Instruction::FCmp:
803887
case Instruction::ICmp:
@@ -902,6 +986,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
902986
<< *Phi << "\n");
903987
return true;
904988
}
989+
if (AddReductionVar(Phi, RecurKind::IFindLastIV, TheLoop, FMF, RedDes, DB, AC,
990+
DT, SE)) {
991+
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
992+
return true;
993+
}
905994
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
906995
SE)) {
907996
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
@@ -1091,6 +1180,9 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
10911180
case RecurKind::FAnyOf:
10921181
return getRecurrenceStartValue();
10931182
break;
1183+
case RecurKind::IFindLastIV:
1184+
case RecurKind::FFindLastIV:
1185+
return getRecurrenceIdentity(RecurKind::SMax, Tp, FMF);
10941186
default:
10951187
llvm_unreachable("Unknown recurrence kind");
10961188
}
@@ -1118,12 +1210,14 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11181210
case RecurKind::UMax:
11191211
case RecurKind::UMin:
11201212
case RecurKind::IAnyOf:
1213+
case RecurKind::IFindLastIV:
11211214
return Instruction::ICmp;
11221215
case RecurKind::FMax:
11231216
case RecurKind::FMin:
11241217
case RecurKind::FMaximum:
11251218
case RecurKind::FMinimum:
11261219
case RecurKind::FAnyOf:
1220+
case RecurKind::FFindLastIV:
11271221
return Instruction::FCmp;
11281222
default:
11291223
llvm_unreachable("Unknown recurrence operation");

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,11 @@ Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
942942
return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
943943
}
944944

945+
Value *llvm::createFindLastIVOp(IRBuilderBase &Builder, Value *Left,
946+
Value *Right) {
947+
return createMinMaxOp(Builder, RecurKind::SMax, Left, Right);
948+
}
949+
945950
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
946951
Value *Right) {
947952
Type *Ty = Left->getType();
@@ -1062,6 +1067,14 @@ Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
10621067
return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
10631068
}
10641069

1070+
Value *llvm::createFindLastIVTargetReduction(IRBuilderBase &Builder, Value *Src,
1071+
const RecurrenceDescriptor &Desc) {
1072+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
1073+
Desc.getRecurrenceKind()) &&
1074+
"Unexpected reduction kind");
1075+
return Builder.CreateIntMaxReduce(Src, true);
1076+
}
1077+
10651078
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
10661079
RecurKind RdxKind) {
10671080
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
@@ -1115,6 +1128,8 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
11151128
RecurKind RK = Desc.getRecurrenceKind();
11161129
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
11171130
return createAnyOfTargetReduction(B, Src, Desc, OrigPhi);
1131+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
1132+
return createFindLastIVTargetReduction(B, Src, Desc);
11181133

11191134
return createSimpleTargetReduction(B, Src, RK);
11201135
}
@@ -1131,6 +1146,16 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
11311146
return B.CreateFAddReduce(Start, Src);
11321147
}
11331148

1149+
Value *llvm::createSentinelValueHandling(IRBuilderBase &Builder,
1150+
const RecurrenceDescriptor &Desc,
1151+
Value *Rdx) {
1152+
Value *InitVal = Desc.getRecurrenceStartValue();
1153+
Value *Iden = Desc.getRecurrenceIdentity(
1154+
Desc.getRecurrenceKind(), Rdx->getType(), Desc.getFastMathFlags());
1155+
Value *Cmp = Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Iden, "rdx.select.cmp");
1156+
return Builder.CreateSelect(Cmp, Rdx, InitVal, "rdx.select");
1157+
}
1158+
11341159
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
11351160
bool IncludeWrapFlags) {
11361161
auto *VecOp = dyn_cast<Instruction>(I);

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3901,6 +3901,8 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
39013901
else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
39023902
ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK,
39033903
ReducedPartRdx, RdxPart);
3904+
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
3905+
ReducedPartRdx = createFindLastIVOp(Builder, ReducedPartRdx, RdxPart);
39043906
else
39053907
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
39063908
}
@@ -3919,6 +3921,10 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
39193921
: Builder.CreateZExt(ReducedPartRdx, PhiTy);
39203922
}
39213923

3924+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
3925+
ReducedPartRdx =
3926+
createSentinelValueHandling(Builder, RdxDesc, ReducedPartRdx);
3927+
39223928
PHINode *ResumePhi =
39233929
dyn_cast<PHINode>(PhiR->getStartValue()->getUnderlyingValue());
39243930

@@ -5822,8 +5828,9 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
58225828
HasReductions &&
58235829
any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
58245830
const RecurrenceDescriptor &RdxDesc = Reduction.second;
5825-
return RecurrenceDescriptor::isAnyOfRecurrenceKind(
5826-
RdxDesc.getRecurrenceKind());
5831+
RecurKind RK = RdxDesc.getRecurrenceKind();
5832+
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
5833+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
58275834
});
58285835
if (HasSelectCmpReductions) {
58295836
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
@@ -8973,8 +8980,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89738980
for (VPReductionPHIRecipe *PhiR : InLoopReductionPhis) {
89748981
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
89758982
RecurKind Kind = RdxDesc.getRecurrenceKind();
8976-
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
8977-
"AnyOf reductions are not allowed for in-loop reductions");
8983+
assert(
8984+
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
8985+
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind)) &&
8986+
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
89788987

89798988
// Collect the chain of "link" recipes for the reduction starting at PhiR.
89808989
SetVector<VPRecipeBase *> Worklist;

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14352,6 +14352,8 @@ class HorizontalReduction {
1435214352
case RecurKind::FMulAdd:
1435314353
case RecurKind::IAnyOf:
1435414354
case RecurKind::FAnyOf:
14355+
case RecurKind::IFindLastIV:
14356+
case RecurKind::FFindLastIV:
1435514357
case RecurKind::None:
1435614358
llvm_unreachable("Unexpected reduction kind for repeated scalar.");
1435714359
}
@@ -14441,6 +14443,8 @@ class HorizontalReduction {
1444114443
case RecurKind::FMulAdd:
1444214444
case RecurKind::IAnyOf:
1444314445
case RecurKind::FAnyOf:
14446+
case RecurKind::IFindLastIV:
14447+
case RecurKind::FFindLastIV:
1444414448
case RecurKind::None:
1444514449
llvm_unreachable("Unexpected reduction kind for reused scalars.");
1444614450
}

0 commit comments

Comments
 (0)