Skip to content

Commit 333616e

Browse files
committed
[LoopVectorize] Vectorize select-cmp reduction pattern for increasing
integer induction variable. Consider the following loop: `int rdx = init; for (int i = 0; i < n; ++i) rdx = (a[i] > b[i]) ? i : rdx;` We can vectorize this loop if `i` is an increasing induction variable. The final reduced value will be the maximum value of `i` for which the condition `a[i] > b[i]` is satisfied, or the start value `init` if the condition is never satisfied. This patch adds two new RecurKind enumerators: IFindLastIV and FFindLastIV.
1 parent b933c84 commit 333616e

File tree

10 files changed

+2069
-41
lines changed

10 files changed

+2069
-41
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,16 @@ enum class RecurKind {
5151
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5252
IAnyOf, ///< Any_of reduction with select(icmp(),x,y) where one of (x,y) is
5353
///< loop invariant, and both x and y are integer type.
54-
FAnyOf ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
54+
FAnyOf, ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
5555
///< loop invariant, and both x and y are integer type.
56-
// TODO: Any_of reduction need not be restricted to integer type only.
56+
IFindLastIV, ///< FindLast reduction with select(icmp(),x,y) where one of
57+
///< (x,y) is an increasing loop induction PHI, and both x and y are
58+
///< integer type.
59+
FFindLastIV ///< FindLast reduction with select(fcmp(),x,y) where one of (x,y)
60+
///< is an increasing loop induction PHI, and both x and y are
61+
///< integer type.
62+
// TODO: Any_of and FindLast reductions need not be restricted to integer type
63+
// only.
5764
};
5865

5966
/// The RecurrenceDescriptor is used to identify recurrences variables in a
@@ -125,7 +132,7 @@ class RecurrenceDescriptor {
125132
/// the returned struct.
126133
static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
127134
RecurKind Kind, InstDesc &Prev,
128-
FastMathFlags FuncFMF);
135+
FastMathFlags FuncFMF, ScalarEvolution *SE);
129136

130137
/// Returns true if instruction I has multiple uses in Insts
131138
static bool hasMultipleUsesOf(Instruction *I,
@@ -152,6 +159,16 @@ class RecurrenceDescriptor {
152159
static InstDesc isAnyOfPattern(Loop *Loop, PHINode *OrigPhi, Instruction *I,
153160
InstDesc &Prev);
154161

162+
/// Returns a struct describing whether the instruction is either a
163+
/// Select(ICmp(A, B), X, Y), or
164+
/// Select(FCmp(A, B), X, Y)
165+
/// where one of (X, Y) is an increasing loop induction variable, and the
166+
/// other is the reduction PHI (\p OrigPhi).
167+
// TODO: FindLast does not need to be restricted to increasing loop induction
168+
// variables.
169+
static InstDesc isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
170+
Instruction *I, ScalarEvolution *SE);
171+
155172
/// Returns a struct describing if the instruction is a
156173
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
157174
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
@@ -240,6 +257,12 @@ class RecurrenceDescriptor {
240257
return Kind == RecurKind::IAnyOf || Kind == RecurKind::FAnyOf;
241258
}
242259

260+
/// Returns true if the recurrence kind is of the form
261+
/// select(cmp(),x,y) where one of (x,y) is an increasing loop induction variable.
262+
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
263+
return Kind == RecurKind::IFindLastIV || Kind == RecurKind::FFindLastIV;
264+
}
265+
243266
/// Returns the type of the recurrence. This type can be narrower than the
244267
/// actual type of the Phi if the recurrence has been type-promoted.
245268
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,12 @@ CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
381381
Value *createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, RecurKind RK,
382382
Value *Left, Value *Right);
383383

384+
/// See RecurrenceDescriptor::isFindLastIVPattern for a description of the
385+
/// pattern we are trying to match. In this pattern, since the selected set of
386+
/// values forms an increasing sequence, we are selecting the maximum value from
387+
/// \p Left and \p Right.
388+
Value *createFindLastIVOp(IRBuilderBase &Builder, Value *Left, Value *Right);
389+
384390
/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
385391
/// The Builder's fast-math-flags must be set to propagate the expected values.
386392
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
@@ -411,6 +417,12 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
411417
const RecurrenceDescriptor &Desc,
412418
PHINode *OrigPhi);
413419

420+
/// Create a target reduction of the given vector \p Src for a reduction of the
421+
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
422+
/// operation is described by \p Desc.
423+
Value *createFindLastIVTargetReduction(IRBuilderBase &B, Value *Src,
424+
const RecurrenceDescriptor &Desc);
425+
414426
/// Create a generic target reduction using a recurrence descriptor \p Desc
415427
/// The target is queried to determine if intrinsics or shuffle sequences are
416428
/// required to implement the reduction.
@@ -424,6 +436,15 @@ Value *createOrderedReduction(IRBuilderBase &B,
424436
const RecurrenceDescriptor &Desc, Value *Src,
425437
Value *Start);
426438

439+
/// Returns a set of cmp and select instructions as shown below:
440+
/// Select(Cmp(NE, Rdx, Iden), Rdx, InitVal)
441+
/// where \p Rdx is a scalar value generated by target reduction, Iden is the
442+
/// sentinel value of the recurrence descriptor \p Desc, and InitVal is the
443+
/// start value of the recurrence descriptor \p Desc.
444+
Value *createSentinelValueHandling(IRBuilderBase &Builder,
445+
const RecurrenceDescriptor &Desc,
446+
Value *Rdx);
447+
427448
/// Get the intersection (logical and) of all of the potential IR flags
428449
/// of each scalar operation (VL) that will be converted into a vector (I).
429450
/// If OpValue is non-null, we only consider operations similar to OpValue

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5252
case RecurKind::UMin:
5353
case RecurKind::IAnyOf:
5454
case RecurKind::FAnyOf:
55+
case RecurKind::IFindLastIV:
56+
case RecurKind::FFindLastIV:
5557
return true;
5658
}
5759
return false;
@@ -373,7 +375,7 @@ bool RecurrenceDescriptor::AddReductionVar(
373375
// type-promoted).
374376
if (Cur != Start) {
375377
ReduxDesc =
376-
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
378+
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
377379
ExactFPMathInst = ExactFPMathInst == nullptr
378380
? ReduxDesc.getExactFPMathInst()
379381
: ExactFPMathInst;
@@ -660,6 +662,87 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
660662
: RecurKind::FAnyOf);
661663
}
662664

665+
// We are looking for loops that do something like this:
666+
// int r = 0;
667+
// for (int i = 0; i < n; i++) {
668+
// if (src[i] > 3)
669+
// r = i;
670+
// }
671+
// The reduction value (r) is derived from either the values of an increasing
672+
// induction variable (i) sequence, or from the start value (0).
673+
// The LLVM IR generated for such loops would be as follows:
674+
// for.body:
675+
// %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
676+
// %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
677+
// ...
678+
// %cmp = icmp sgt i32 %5, 3
679+
// %spec.select = select i1 %cmp, i32 %i, i32 %r
680+
// %inc = add nsw i32 %i, 1
681+
// ...
682+
// Since 'i' is an increasing induction variable, the reduction value after the
683+
// loop will be the maximum value of 'i' for which the condition (src[i] > 3) is
684+
// satisfied, or the start value (0 in the example above). When the start value
685+
// of the increasing induction variable 'i' is greater than the minimum value of
686+
// the data type, we can use the minimum value of the data type as a sentinel
687+
// value to replace the start value. This allows us to perform a single
688+
// reduction max operation to obtain the final reduction result.
689+
// TODO: It is possible to solve the case where the start value is the minimum
690+
// value of the data type or a non-constant value by using mask and multiple
691+
// reduction operations.
692+
RecurrenceDescriptor::InstDesc
693+
RecurrenceDescriptor::isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
694+
Instruction *I, ScalarEvolution *SE) {
695+
// Only match select with single use cmp condition.
696+
// TODO: Support cmp conditions that have more than one use.
697+
CmpInst::Predicate Pred;
698+
if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
699+
m_Value())))
700+
return InstDesc(false, I);
701+
702+
SelectInst *SI = cast<SelectInst>(I);
703+
Value *NonRdxPhi = nullptr;
704+
705+
if (OrigPhi == dyn_cast<PHINode>(SI->getTrueValue()))
706+
NonRdxPhi = SI->getFalseValue();
707+
else if (OrigPhi == dyn_cast<PHINode>(SI->getFalseValue()))
708+
NonRdxPhi = SI->getTrueValue();
709+
else
710+
return InstDesc(false, I);
711+
712+
auto IsIncreasingLoopInduction = [&SE, &Loop](Value *V) {
713+
auto *Phi = dyn_cast<PHINode>(V);
714+
if (!Phi)
715+
return false;
716+
717+
if (!SE)
718+
return false;
719+
720+
InductionDescriptor ID;
721+
if (!InductionDescriptor::isInductionPHI(Phi, Loop, SE, ID))
722+
return false;
723+
724+
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi));
725+
if (!AR->hasNoSignedWrap())
726+
return false;
727+
728+
ConstantInt *IVStartValue = dyn_cast<ConstantInt>(ID.getStartValue());
729+
if (!IVStartValue || IVStartValue->isMinSignedValue())
730+
return false;
731+
732+
const SCEV *Step = ID.getStep();
733+
return SE->isKnownPositive(Step);
734+
};
735+
736+
// We are looking for selects of the form:
737+
// select(cmp(), phi, loop_induction) or
738+
// select(cmp(), loop_induction, phi)
739+
if (!IsIncreasingLoopInduction(NonRdxPhi))
740+
return InstDesc(false, I);
741+
742+
return InstDesc(I, isa<ICmpInst>(I->getOperand(0)) ? RecurKind::IFindLastIV
743+
: RecurKind::FFindLastIV);
744+
}
745+
663746
RecurrenceDescriptor::InstDesc
664747
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
665748
const InstDesc &Prev) {
@@ -763,10 +846,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
763846
return InstDesc(true, SI);
764847
}
765848

766-
RecurrenceDescriptor::InstDesc
767-
RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
768-
Instruction *I, RecurKind Kind,
769-
InstDesc &Prev, FastMathFlags FuncFMF) {
849+
RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
850+
Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
851+
FastMathFlags FuncFMF, ScalarEvolution *SE) {
770852
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
771853
switch (I->getOpcode()) {
772854
default:
@@ -796,6 +878,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
796878
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
797879
Kind == RecurKind::Add || Kind == RecurKind::Mul)
798880
return isConditionalRdxPattern(Kind, I);
881+
if (isFindLastIVRecurrenceKind(Kind))
882+
return isFindLastIVPattern(L, OrigPhi, I, SE);
799883
[[fallthrough]];
800884
case Instruction::FCmp:
801885
case Instruction::ICmp:
@@ -900,6 +984,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
900984
<< *Phi << "\n");
901985
return true;
902986
}
987+
if (AddReductionVar(Phi, RecurKind::IFindLastIV, TheLoop, FMF, RedDes, DB, AC,
988+
DT, SE)) {
989+
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
990+
return true;
991+
}
903992
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
904993
SE)) {
905994
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
@@ -1089,6 +1178,9 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
10891178
case RecurKind::FAnyOf:
10901179
return getRecurrenceStartValue();
10911180
break;
1181+
case RecurKind::IFindLastIV:
1182+
case RecurKind::FFindLastIV:
1183+
return getRecurrenceIdentity(RecurKind::SMax, Tp, FMF);
10921184
default:
10931185
llvm_unreachable("Unknown recurrence kind");
10941186
}
@@ -1116,12 +1208,14 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11161208
case RecurKind::UMax:
11171209
case RecurKind::UMin:
11181210
case RecurKind::IAnyOf:
1211+
case RecurKind::IFindLastIV:
11191212
return Instruction::ICmp;
11201213
case RecurKind::FMax:
11211214
case RecurKind::FMin:
11221215
case RecurKind::FMaximum:
11231216
case RecurKind::FMinimum:
11241217
case RecurKind::FAnyOf:
1218+
case RecurKind::FFindLastIV:
11251219
return Instruction::FCmp;
11261220
default:
11271221
llvm_unreachable("Unknown recurrence operation");

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,11 @@ Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
10431043
return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
10441044
}
10451045

1046+
Value *llvm::createFindLastIVOp(IRBuilderBase &Builder, Value *Left,
1047+
Value *Right) {
1048+
return createMinMaxOp(Builder, RecurKind::SMax, Left, Right);
1049+
}
1050+
10461051
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
10471052
Value *Right) {
10481053
Type *Ty = Left->getType();
@@ -1163,6 +1168,14 @@ Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
11631168
return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
11641169
}
11651170

1171+
Value *llvm::createFindLastIVTargetReduction(IRBuilderBase &Builder, Value *Src,
1172+
const RecurrenceDescriptor &Desc) {
1173+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
1174+
Desc.getRecurrenceKind()) &&
1175+
"Unexpected reduction kind");
1176+
return Builder.CreateIntMaxReduce(Src, true);
1177+
}
1178+
11661179
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
11671180
RecurKind RdxKind) {
11681181
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
@@ -1216,6 +1229,8 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
12161229
RecurKind RK = Desc.getRecurrenceKind();
12171230
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
12181231
return createAnyOfTargetReduction(B, Src, Desc, OrigPhi);
1232+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
1233+
return createFindLastIVTargetReduction(B, Src, Desc);
12191234

12201235
return createSimpleTargetReduction(B, Src, RK);
12211236
}
@@ -1232,6 +1247,16 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
12321247
return B.CreateFAddReduce(Start, Src);
12331248
}
12341249

1250+
Value *llvm::createSentinelValueHandling(IRBuilderBase &Builder,
1251+
const RecurrenceDescriptor &Desc,
1252+
Value *Rdx) {
1253+
Value *InitVal = Desc.getRecurrenceStartValue();
1254+
Value *Iden = Desc.getRecurrenceIdentity(
1255+
Desc.getRecurrenceKind(), Rdx->getType(), Desc.getFastMathFlags());
1256+
Value *Cmp = Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Iden, "rdx.select.cmp");
1257+
return Builder.CreateSelect(Cmp, Rdx, InitVal, "rdx.select");
1258+
}
1259+
12351260
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
12361261
bool IncludeWrapFlags) {
12371262
auto *VecOp = dyn_cast<Instruction>(I);

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5510,8 +5510,9 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
55105510
HasReductions &&
55115511
any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
55125512
const RecurrenceDescriptor &RdxDesc = Reduction.second;
5513-
return RecurrenceDescriptor::isAnyOfRecurrenceKind(
5514-
RdxDesc.getRecurrenceKind());
5513+
RecurKind RK = RdxDesc.getRecurrenceKind();
5514+
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
5515+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
55155516
});
55165517
if (HasSelectCmpReductions) {
55175518
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
@@ -8946,8 +8947,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89468947

89478948
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
89488949
RecurKind Kind = RdxDesc.getRecurrenceKind();
8949-
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
8950-
"AnyOf reductions are not allowed for in-loop reductions");
8950+
assert(
8951+
(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
8952+
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind)) &&
8953+
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
89518954

89528955
// Collect the chain of "link" recipes for the reduction starting at PhiR.
89538956
SetVector<VPSingleDefRecipe *> Worklist;

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17008,6 +17008,8 @@ class HorizontalReduction {
1700817008
case RecurKind::FMulAdd:
1700917009
case RecurKind::IAnyOf:
1701017010
case RecurKind::FAnyOf:
17011+
case RecurKind::IFindLastIV:
17012+
case RecurKind::FFindLastIV:
1701117013
case RecurKind::None:
1701217014
llvm_unreachable("Unexpected reduction kind for repeated scalar.");
1701317015
}
@@ -17110,6 +17112,8 @@ class HorizontalReduction {
1711017112
case RecurKind::FMulAdd:
1711117113
case RecurKind::IAnyOf:
1711217114
case RecurKind::FAnyOf:
17115+
case RecurKind::IFindLastIV:
17116+
case RecurKind::FFindLastIV:
1711317117
case RecurKind::None:
1711417118
llvm_unreachable("Unexpected reduction kind for reused scalars.");
1711517119
}

0 commit comments

Comments
 (0)