Skip to content

Commit 1e8ef2e

Browse files
committed
[LV] Add support for cmp reductions with decreasing IVs using SMin.
Similar to FindLastIV, add FindFirstIV to support select (icmp(), x, y) reductions where one of x or y is a decreasing induction. This is done via a new recurrence kind FindFirstIVSMin, which selects the first value from the reduction vector using smin instead of the last value (FindLastIV). It uses signed max as sentinel value. The
1 parent 5ffdd94 commit 1e8ef2e

File tree

11 files changed

+1140
-148
lines changed

11 files changed

+1140
-148
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ enum class RecurKind {
5454
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5555
AnyOf, ///< AnyOf reduction with select(cmp(),x,y) where one of (x,y) is
5656
///< loop invariant, and both x and y are integer type.
57+
FindFirstIVSMin, /// FindFirst reduction with select(icmp(),x,y) where one of
58+
///< (x,y) is a decreasing loop induction, and both x and y
59+
///< are integer type, producing a SMin reduction.
5760
FindLastIVSMax, ///< FindLast reduction with select(cmp(),x,y) where one of
5861
///< (x,y) is increasing loop induction, and both x and y
5962
///< are integer type, producing a SMax reduction.
@@ -165,13 +168,13 @@ class RecurrenceDescriptor {
165168
/// Returns a struct describing whether the instruction is either a
166169
/// Select(ICmp(A, B), X, Y), or
167170
/// Select(FCmp(A, B), X, Y)
168-
/// where one of (X, Y) is an increasing loop induction variable, and the
169-
/// other is a PHI value.
171+
/// where one of (X, Y) is an increasing (FindLast) or decreasing (FindFirst)
172+
/// loop induction variable, and the other is a PHI value.
170173
// TODO: Support non-monotonic variable. FindLast does not need be restricted
171174
// to increasing loop induction variables.
172-
LLVM_ABI static InstDesc isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
173-
Instruction *I,
174-
ScalarEvolution &SE);
175+
LLVM_ABI static InstDesc isFindIVPattern(RecurKind Kind, Loop *TheLoop,
176+
PHINode *OrigPhi, Instruction *I,
177+
ScalarEvolution &SE);
175178

176179
/// Returns a struct describing if the instruction is a
177180
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
@@ -259,6 +262,12 @@ class RecurrenceDescriptor {
259262
return Kind == RecurKind::AnyOf;
260263
}
261264

265+
/// Returns true if the recurrence kind is of the form
266+
/// select(cmp(),x,y) where one of (x,y) is decreasing loop induction.
267+
static bool isFindFirstIVRecurrenceKind(RecurKind Kind) {
268+
return Kind == RecurKind::FindFirstIVSMin;
269+
}
270+
262271
/// Returns true if the recurrence kind is of the form
263272
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
264273
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
@@ -269,22 +278,35 @@ class RecurrenceDescriptor {
269278
/// Returns true if recurrece kind is a signed redux kind.
270279
static bool isSignedRecurrenceKind(RecurKind Kind) {
271280
return Kind == RecurKind::SMax || Kind == RecurKind::SMin ||
281+
Kind == RecurKind::FindFirstIVSMin ||
272282
Kind == RecurKind::FindLastIVSMax;
273283
}
274284

285+
/// Returns true if the recurrence kind is of the form
286+
/// select(cmp(),x,y) where one of (x,y) is an increasing or decreasing loop
287+
/// induction.
288+
static bool isFindIVRecurrenceKind(RecurKind Kind) {
289+
return isFindFirstIVRecurrenceKind(Kind) ||
290+
isFindLastIVRecurrenceKind(Kind);
291+
}
292+
275293
/// Returns the type of the recurrence. This type can be narrower than the
276294
/// actual type of the Phi if the recurrence has been type-promoted.
277295
Type *getRecurrenceType() const { return RecurrenceType; }
278296

279-
/// Returns the sentinel value for FindLastIV recurrences to replace the start
280-
/// value.
297+
/// Returns the sentinel value for FindFirstIV &FindLastIV recurrences to
298+
/// replace the start value.
281299
Value *getSentinelValue() const {
282-
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
283300
Type *Ty = StartValue->getType();
284301
unsigned BW = Ty->getIntegerBitWidth();
302+
if (isFindLastIVRecurrenceKind(Kind)) {
303+
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
304+
? APInt::getSignedMinValue(BW)
305+
: APInt::getMinValue(BW));
306+
}
285307
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
286-
? APInt::getSignedMinValue(BW)
287-
: APInt::getMinValue(BW));
308+
? APInt::getSignedMaxValue(BW)
309+
: APInt::getMaxValue(BW));
288310
}
289311

290312
/// Returns a reference to the instructions used for type-promoting the

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5050
case RecurKind::UMax:
5151
case RecurKind::UMin:
5252
case RecurKind::AnyOf:
53+
case RecurKind::FindFirstIVSMin:
5354
case RecurKind::FindLastIVSMax:
5455
case RecurKind::FindLastIVUMax:
5556
return true;
@@ -684,8 +685,9 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
684685
// value of the data type or a non-constant value by using mask and multiple
685686
// reduction operations.
686687
RecurrenceDescriptor::InstDesc
687-
RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
688-
Instruction *I, ScalarEvolution &SE) {
688+
RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
689+
PHINode *OrigPhi, Instruction *I,
690+
ScalarEvolution &SE) {
689691
// TODO: Support the vectorization of FindLastIV when the reduction phi is
690692
// used by more than one select instruction. This vectorization is only
691693
// performed when the SCEV of each increasing induction variable used by the
@@ -713,36 +715,68 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
713715
return std::nullopt;
714716

715717
const SCEV *Step = AR->getStepRecurrence(SE);
716-
if (!SE.isKnownPositive(Step))
718+
719+
if (isFindFirstIVRecurrenceKind(Kind)) {
720+
if (!SE.isKnownNegative(Step))
721+
return std::nullopt;
722+
} else if (!SE.isKnownPositive(Step))
717723
return std::nullopt;
718724

719725
// Keep the minimum value of the recurrence type as the sentinel value.
720726
// The maximum acceptable range for the increasing induction variable,
721727
// called the valid range, will be defined as
728+
729+
const ConstantRange IVRange = SE.getSignedRange(AR);
730+
// Keep the minimum (FindLast) or maximum (FindFirst) value of the
731+
// recurrence type as the sentinel value. The maximum acceptable range for
732+
// the induction variable, called the valid range, will be defined as
722733
// [<sentinel value> + 1, <sentinel value>)
723-
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>)
734+
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>) for
735+
// FindLastIV or [Signed|Unsigned]Max(<recurrence type>) for FindFirstIV.
724736
// TODO: This range restriction can be lifted by adding an additional
725737
// virtual OR reduction.
726738
auto CheckRange = [&](bool IsSigned) {
727739
const ConstantRange IVRange =
728740
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
729741
unsigned NumBits = Ty->getIntegerBitWidth();
730-
const APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
731-
: APInt::getMinValue(NumBits);
732-
const ConstantRange ValidRange =
733-
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
734-
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
742+
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
743+
if (isFindLastIVRecurrenceKind(Kind)) {
744+
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
745+
: APInt::getMinValue(NumBits);
746+
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
747+
} else {
748+
assert(isFindFirstIVRecurrenceKind(Kind) &&
749+
"Kind must either be a FindLastIV or FindFirstIV");
750+
assert(IsSigned &&
751+
"only FindFirstIV with SMax is supported at the moment");
752+
ValidRange =
753+
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
754+
APInt::getSignedMaxValue(NumBits) - 1);
755+
}
756+
757+
LLVM_DEBUG(dbgs() << "LV: "
758+
<< (isFindLastIVRecurrenceKind(Kind) ? "FindLastIV"
759+
: "FindFirstIV")
760+
<< " valid range is " << ValidRange
735761
<< ", and the range of " << *AR << " is " << IVRange
736762
<< "\n");
737763

738764
// Ensure the induction variable does not wrap around by verifying that
739765
// its range is fully contained within the valid range.
740766
return ValidRange.contains(IVRange);
741767
};
768+
if (isFindLastIVRecurrenceKind(Kind)) {
769+
if (CheckRange(true))
770+
return RecurKind::FindLastIVSMax;
771+
if (CheckRange(false))
772+
return RecurKind::FindLastIVUMax;
773+
return std::nullopt;
774+
}
775+
assert(isFindFirstIVRecurrenceKind(Kind) &&
776+
"Kind must either be a FindLastIV or FindFirstIV");
777+
742778
if (CheckRange(true))
743-
return RecurKind::FindLastIVSMax;
744-
if (CheckRange(false))
745-
return RecurKind::FindLastIVUMax;
779+
return RecurKind::FindFirstIVSMin;
746780
return std::nullopt;
747781
};
748782

@@ -888,8 +922,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
888922
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
889923
Kind == RecurKind::Add || Kind == RecurKind::Mul)
890924
return isConditionalRdxPattern(I);
891-
if (isFindLastIVRecurrenceKind(Kind) && SE)
892-
return isFindLastIVPattern(L, OrigPhi, I, *SE);
925+
if (isFindIVRecurrenceKind(Kind) && SE)
926+
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
893927
[[fallthrough]];
894928
case Instruction::FCmp:
895929
case Instruction::ICmp:
@@ -1003,6 +1037,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
10031037
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
10041038
return true;
10051039
}
1040+
if (AddReductionVar(Phi, RecurKind::FindFirstIVSMin, TheLoop, FMF, RedDes, DB,
1041+
AC, DT, SE)) {
1042+
LLVM_DEBUG(dbgs() << "Found a FindFirstIV reduction PHI." << *Phi << "\n");
1043+
return true;
1044+
}
10061045
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
10071046
SE)) {
10081047
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
@@ -1150,6 +1189,7 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11501189
case RecurKind::Mul:
11511190
return Instruction::Mul;
11521191
case RecurKind::AnyOf:
1192+
case RecurKind::FindFirstIVSMin:
11531193
case RecurKind::FindLastIVSMax:
11541194
case RecurKind::FindLastIVUMax:
11551195
case RecurKind::Or:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,9 +1227,12 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
12271227
RecurKind RdxKind, Value *Start,
12281228
Value *Sentinel) {
12291229
bool IsSigned = RecurrenceDescriptor::isSignedRecurrenceKind(RdxKind);
1230-
Value *MaxRdx = Src->getType()->isVectorTy()
1231-
? Builder.CreateIntMaxReduce(Src, IsSigned)
1232-
: Src;
1230+
Value *MaxRdx =
1231+
Src->getType()->isVectorTy()
1232+
? (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RdxKind)
1233+
? Builder.CreateIntMaxReduce(Src, IsSigned)
1234+
: Builder.CreateIntMinReduce(Src, IsSigned))
1235+
: Src;
12331236
// Correct the final reduction result back to the start value if the maximum
12341237
// reduction is sentinel value.
12351238
Value *Cmp =
@@ -1324,8 +1327,8 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13241327
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13251328
RecurKind Kind, Value *Mask, Value *EVL) {
13261329
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
1327-
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
1328-
"AnyOf or FindLastIV reductions are not supported.");
1330+
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
1331+
"AnyOf, FindFirstIV and FindLastIV reductions are not supported.");
13291332
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
13301333
auto VPID = VPIntrinsic::getForIntrinsic(Id);
13311334
assert(VPReductionIntrinsic::isVPReduction(VPID) &&

0 commit comments

Comments
 (0)