Skip to content

Commit fce52d0

Browse files
committed
Unify Find recurrence detection.
1 parent 142d881 commit fce52d0

File tree

5 files changed

+33
-54
lines changed

5 files changed

+33
-54
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,8 @@ class RecurrenceDescriptor {
181181
/// where one of (X, Y) is an increasing (FindLastIV) or decreasing
182182
/// (FindFirstIV) loop induction variable, or an arbitrary integer value
183183
/// (FindLast), and the other is a PHI value.
184-
LLVM_ABI static InstDesc isFindPattern(RecurKind Kind, Loop *TheLoop,
185-
PHINode *OrigPhi, Instruction *I,
186-
ScalarEvolution &SE);
184+
LLVM_ABI static InstDesc isFindPattern(Loop *TheLoop, PHINode *OrigPhi,
185+
Instruction *I, ScalarEvolution &SE);
187186

188187
/// Returns a struct describing if the instruction is a
189188
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
@@ -314,6 +313,10 @@ class RecurrenceDescriptor {
314313
return Kind == RecurKind::FindLast;
315314
}
316315

316+
static bool isFindRecurrenceKind(RecurKind Kind) {
317+
return isFindLastRecurrenceKind(Kind) || isFindIVRecurrenceKind(Kind);
318+
}
319+
317320
/// Returns the type of the recurrence. This type can be narrower than the
318321
/// actual type of the Phi if the recurrence has been type-promoted.
319322
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -707,9 +707,8 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
707707
// will either be the initial value (0) if the condition was never met, or the
708708
// value of a[i] in the most recent loop iteration where the condition was met.
709709
RecurrenceDescriptor::InstDesc
710-
RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
711-
PHINode *OrigPhi, Instruction *I,
712-
ScalarEvolution &SE) {
710+
RecurrenceDescriptor::isFindPattern(Loop *TheLoop, PHINode *OrigPhi,
711+
Instruction *I, ScalarEvolution &SE) {
713712
// TODO: Support the vectorization of FindLastIV when the reduction phi is
714713
// used by more than one select instruction. This vectorization is only
715714
// performed when the SCEV of each increasing induction variable used by the
@@ -730,17 +729,6 @@ RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
730729
m_Value(NonRdxPhi)))))
731730
return InstDesc(false, I);
732731

733-
if (isFindLastRecurrenceKind(Kind)) {
734-
// Must be an integer scalar.
735-
Type *Type = OrigPhi->getType();
736-
if (!Type->isIntegerTy())
737-
return InstDesc(false, I);
738-
739-
// FIXME: Support more complex patterns, including multiple selects.
740-
// The Select must be used only outside the loop and by the PHI.
741-
return InstDesc(I, RecurKind::FindLast);
742-
}
743-
744732
// Returns either FindFirstIV/FindLastIV, if such a pattern is found, or
745733
// std::nullopt.
746734
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
@@ -754,8 +742,9 @@ RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
754742
m_SpecificLoop(TheLoop))))
755743
return std::nullopt;
756744

757-
if ((isFindFirstIVRecurrenceKind(Kind) && !SE.isKnownNegative(Step)) ||
758-
(isFindLastIVRecurrenceKind(Kind) && !SE.isKnownPositive(Step)))
745+
// We must have a known positive or negative step for FindIV
746+
const bool PositiveStep = SE.isKnownPositive(Step);
747+
if (!PositiveStep && !SE.isKnownNegative(Step))
759748
return std::nullopt;
760749

761750
// Check if the minimum (FindLast) or maximum (FindFirst) value of the
@@ -771,7 +760,7 @@ RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
771760
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
772761
unsigned NumBits = Ty->getIntegerBitWidth();
773762
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
774-
if (isFindLastIVRecurrenceKind(Kind)) {
763+
if (PositiveStep) {
775764
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
776765
: APInt::getMinValue(NumBits);
777766
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
@@ -785,26 +774,22 @@ RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
785774
APInt::getMinValue(NumBits), APInt::getMaxValue(NumBits) - 1);
786775
}
787776

788-
LLVM_DEBUG(dbgs() << "LV: "
789-
<< (isFindLastIVRecurrenceKind(Kind) ? "FindLastIV"
790-
: "FindFirstIV")
791-
<< " valid range is " << ValidRange
792-
<< ", and the range of " << *AR << " is " << IVRange
793-
<< "\n");
777+
LLVM_DEBUG(
778+
dbgs() << "LV: " << (PositiveStep ? "FindLastIV" : "FindFirstIV")
779+
<< " valid range is " << ValidRange << ", and the range of "
780+
<< *AR << " is " << IVRange << "\n");
794781

795782
// Ensure the induction variable does not wrap around by verifying that
796783
// its range is fully contained within the valid range.
797784
return ValidRange.contains(IVRange);
798785
};
799-
if (isFindLastIVRecurrenceKind(Kind)) {
786+
if (PositiveStep) {
800787
if (CheckRange(true))
801788
return RecurKind::FindLastIVSMax;
802789
if (CheckRange(false))
803790
return RecurKind::FindLastIVUMax;
804791
return std::nullopt;
805792
}
806-
assert(isFindFirstIVRecurrenceKind(Kind) &&
807-
"Kind must either be a FindLastIV or FindFirstIV");
808793

809794
if (CheckRange(true))
810795
return RecurKind::FindFirstIVSMin;
@@ -816,7 +801,8 @@ RecurrenceDescriptor::isFindPattern(RecurKind Kind, Loop *TheLoop,
816801
if (auto RK = GetRecurKind(NonRdxPhi))
817802
return InstDesc(I, *RK);
818803

819-
return InstDesc(false, I);
804+
// If the recurrence is not specific to an IV, return a generic FindLast.
805+
return InstDesc(I, RecurKind::FindLast);
820806
}
821807

822808
RecurrenceDescriptor::InstDesc
@@ -950,8 +936,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
950936
Kind == RecurKind::Add || Kind == RecurKind::Mul ||
951937
Kind == RecurKind::Sub || Kind == RecurKind::AddChainWithSubs)
952938
return isConditionalRdxPattern(I);
953-
if ((isFindIVRecurrenceKind(Kind) || isFindLastRecurrenceKind(Kind)) && SE)
954-
return isFindPattern(Kind, L, OrigPhi, I, *SE);
939+
if (isFindRecurrenceKind(Kind) && SE)
940+
return isFindPattern(L, OrigPhi, I, *SE);
955941
[[fallthrough]];
956942
case Instruction::FCmp:
957943
case Instruction::ICmp:
@@ -1091,14 +1077,9 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
10911077
<< "\n");
10921078
return true;
10931079
}
1094-
if (AddReductionVar(Phi, RecurKind::FindLastIVSMax, TheLoop, FMF, RedDes, DB,
1095-
AC, DT, SE)) {
1096-
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
1097-
return true;
1098-
}
1099-
if (AddReductionVar(Phi, RecurKind::FindFirstIVSMin, TheLoop, FMF, RedDes, DB,
1100-
AC, DT, SE)) {
1101-
LLVM_DEBUG(dbgs() << "Found a FindFirstIV reduction PHI." << *Phi << "\n");
1080+
if (AddReductionVar(Phi, RecurKind::FindLast, TheLoop, FMF, RedDes, DB, AC,
1081+
DT, SE)) {
1082+
LLVM_DEBUG(dbgs() << "Found a Find reduction PHI." << *Phi << "\n");
11021083
return true;
11031084
}
11041085
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
@@ -1148,11 +1129,6 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
11481129
<< "\n");
11491130
return true;
11501131
}
1151-
if (AddReductionVar(Phi, RecurKind::FindLast, TheLoop, FMF, RedDes, DB, AC,
1152-
DT, SE)) {
1153-
LLVM_DEBUG(dbgs() << "Found a FindLast reduction PHI." << *Phi << "\n");
1154-
return true;
1155-
}
11561132
// Not a reduction of known type.
11571133
return false;
11581134
}
@@ -1278,7 +1254,6 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12781254
case RecurKind::FMinimumNum:
12791255
return Instruction::FCmp;
12801256
case RecurKind::FindLast:
1281-
return Instruction::Select;
12821257
case RecurKind::AnyOf:
12831258
case RecurKind::FindFirstIVSMin:
12841259
case RecurKind::FindFirstIVUMin:

llvm/lib/Transforms/Utils/LoopUnroll.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,9 +1259,8 @@ llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L,
12591259
// reductions.
12601260
if (!RecurrenceDescriptor::isIntegerRecurrenceKind(RK) ||
12611261
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
1262-
RecurrenceDescriptor::isFindIVRecurrenceKind(RK) ||
1263-
RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) ||
1264-
RecurrenceDescriptor::isFindLastRecurrenceKind(RK))
1262+
RecurrenceDescriptor::isFindRecurrenceKind(RK) ||
1263+
RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))
12651264
return std::nullopt;
12661265

12671266
if (RdxDesc.IntermediateStore)

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1491,7 +1491,7 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
14911491
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
14921492
RecurKind Kind, Value *Mask, Value *EVL) {
14931493
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
1494-
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
1494+
!RecurrenceDescriptor::isFindRecurrenceKind(Kind) &&
14951495
"AnyOf and FindIV reductions are not supported.");
14961496
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
14971497
auto VPID = VPIntrinsic::getForIntrinsic(Id);

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4617,10 +4617,12 @@ LoopVectorizationPlanner::selectInterleaveCount(VPlan &Plan, ElementCount VF,
46174617
IsaPred<VPReductionPHIRecipe>);
46184618

46194619
// FIXME: implement interleaving for FindLast transform correctly.
4620-
for (auto &[_, RdxDesc] : Legal->getReductionVars())
4621-
if (RecurrenceDescriptor::isFindLastRecurrenceKind(
4622-
RdxDesc.getRecurrenceKind()))
4623-
return 1;
4620+
if (any_of(make_second_range(Legal->getReductionVars()),
4621+
[](const RecurrenceDescriptor &RdxDesc) {
4622+
return RecurrenceDescriptor::isFindLastRecurrenceKind(
4623+
RdxDesc.getRecurrenceKind());
4624+
}))
4625+
return 1;
46244626

46254627
// If we did not calculate the cost for VF (because the user selected the VF)
46264628
// then we calculate the cost of VF here.

0 commit comments

Comments
 (0)