Skip to content

Commit 113aa0f

Browse files
committed
[LV] Support argmin/argmax with strict predicates.
Extend handleMultiUseReductions to support strict predicates (>, <), which must match the first index rather than the last index matched for non-strict predicates. Builds on top of #141431. When a strict predicate is detected, the transformation converts the FindLastIV reduction to FindFirstIV by: 1. Checking the IV range to ensure it does not include the sentinel value (max). 2. Creating a new reduction with the appropriate FindFirstIV kind (FindFirstIVSMin or FindFirstIVUMin, based on the IV range). 3. Replacing the old reduction recipe with the new one.
1 parent 2864afb commit 113aa0f

File tree

14 files changed

+1115
-247
lines changed

14 files changed

+1115
-247
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Loop;
2828
class PredicatedScalarEvolution;
2929
class ScalarEvolution;
3030
class SCEV;
31+
class SCEVAddRecExpr;
3132
class StoreInst;
3233

3334
/// These are the kinds of recurrences that we support.
@@ -310,6 +311,11 @@ class RecurrenceDescriptor {
310311
isFindLastIVRecurrenceKind(Kind);
311312
}
312313

314+
/// Returns true if \p AR's range is valid for either FindFirstIV or
315+
/// FindLastIV reductions i.e. if the sentinel value is outside \p AR's range.
316+
static bool isValidIVRangeForFindIV(const SCEVAddRecExpr *AR, bool IsSigned,
317+
bool IsFindFirstIV, ScalarEvolution &SE);
318+
313319
/// Returns the type of the recurrence. This type can be narrower than the
314320
/// actual type of the Phi if the recurrence has been type-promoted.
315321
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,36 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
715715
return InstDesc(I, RecurKind::AnyOf);
716716
}
717717

718+
// Checks whether the value range ScalarEvolution reports for AR fits inside
// the range a FindFirstIV/FindLastIV reduction can use without colliding
// with its sentinel value.
//   AR            - add-recurrence whose range is queried; its integer bit
//                   width also determines the width of the valid range.
//   IsSigned      - use the signed range and signed min/max when true, the
//                   unsigned ones otherwise.
//   IsFindFirstIV - true: the valid range is bounded by the maximum value;
//                   false: the valid range is built around the minimum value.
//   SE            - ScalarEvolution instance used for the range query.
// Returns true iff the valid range contains AR's range.
bool RecurrenceDescriptor::isValidIVRangeForFindIV(const SCEVAddRecExpr *AR,
719+
bool IsSigned,
720+
bool IsFindFirstIV,
721+
ScalarEvolution &SE) {
722+
const ConstantRange IVRange =
723+
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
724+
unsigned NumBits = AR->getType()->getIntegerBitWidth();
725+
// Start from the empty range; both branches below overwrite it.
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
726+
727+
// FindFirstIV: the range runs from the (signed or unsigned) minimum up to
// maximum - 1.
// NOTE(review): if ConstantRange's upper bound is exclusive, this excludes
// both max and max - 1 -- confirm that is intended.
if (IsFindFirstIV) {
728+
if (IsSigned)
729+
ValidRange =
730+
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
731+
APInt::getSignedMaxValue(NumBits) - 1);
732+
else
733+
ValidRange = ConstantRange::getNonEmpty(APInt::getMinValue(NumBits),
734+
APInt::getMaxValue(NumBits) - 1);
735+
} else {
736+
// FindLastIV: the sentinel is the (signed or unsigned) minimum value.
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
737+
: APInt::getMinValue(NumBits);
738+
// Range from Sentinel + 1 wrapping around to Sentinel.
// NOTE(review): assuming an exclusive upper bound, this is every value
// except Sentinel itself.
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
739+
}
740+
741+
LLVM_DEBUG(dbgs() << "LV: " << (IsFindFirstIV ? "FindFirstIV" : "FindLastIV")
742+
<< " valid range is " << ValidRange << ", and the range of "
743+
<< *AR << " is " << IVRange << "\n");
744+
745+
// The IV range is acceptable only if it lies entirely within ValidRange.
return ValidRange.contains(IVRange);
746+
}
747+
718748
// We are looking for loops that do something like this:
719749
// int r = 0;
720750
// for (int i = 0; i < n; i++) {
@@ -792,49 +822,24 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
792822
// [Signed|Unsigned]Max(<recurrence type>) for FindFirstIV.
793823
// TODO: This range restriction can be lifted by adding an additional
794824
// virtual OR reduction.
795-
auto CheckRange = [&](bool IsSigned) {
796-
const ConstantRange IVRange =
797-
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
798-
unsigned NumBits = Ty->getIntegerBitWidth();
799-
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
800-
if (isFindLastIVRecurrenceKind(Kind)) {
801-
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
802-
: APInt::getMinValue(NumBits);
803-
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
804-
} else {
805-
if (IsSigned)
806-
ValidRange =
807-
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
808-
APInt::getSignedMaxValue(NumBits) - 1);
809-
else
810-
ValidRange = ConstantRange::getNonEmpty(
811-
APInt::getMinValue(NumBits), APInt::getMaxValue(NumBits) - 1);
812-
}
813-
814-
LLVM_DEBUG(dbgs() << "LV: "
815-
<< (isFindLastIVRecurrenceKind(Kind) ? "FindLastIV"
816-
: "FindFirstIV")
817-
<< " valid range is " << ValidRange
818-
<< ", and the range of " << *AR << " is " << IVRange
819-
<< "\n");
820-
821-
// Ensure the induction variable does not wrap around by verifying that
822-
// its range is fully contained within the valid range.
823-
return ValidRange.contains(IVRange);
824-
};
825+
bool IsFindFirstIV = isFindFirstIVRecurrenceKind(Kind);
825826
if (isFindLastIVRecurrenceKind(Kind)) {
826-
if (CheckRange(true))
827+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
828+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/true, IsFindFirstIV, SE))
827829
return RecurKind::FindLastIVSMax;
828-
if (CheckRange(false))
830+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
831+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/false, IsFindFirstIV, SE))
829832
return RecurKind::FindLastIVUMax;
830833
return std::nullopt;
831834
}
832835
assert(isFindFirstIVRecurrenceKind(Kind) &&
833836
"Kind must either be a FindLastIV or FindFirstIV");
834837

835-
if (CheckRange(true))
838+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
839+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/true, IsFindFirstIV, SE))
836840
return RecurKind::FindFirstIVSMin;
837-
if (CheckRange(false))
841+
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
842+
cast<SCEVAddRecExpr>(AR), /*IsSigned=*/false, IsFindFirstIV, SE))
838843
return RecurKind::FindFirstIVUMin;
839844
return std::nullopt;
840845
};

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8576,8 +8576,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
85768576

85778577
// Apply mandatory transformation to handle reductions with multiple in-loop
85788578
// uses if possible, bail out otherwise.
8579-
if (!VPlanTransforms::runPass(VPlanTransforms::handleMultiUseReductions,
8580-
*Plan))
8579+
if (!VPlanTransforms::handleMultiUseReductions(*Plan, *PSE.getSE(), OrigLoop))
85818580
return nullptr;
85828581
// Apply mandatory transformation to handle FP maxnum/minnum reduction with
85838582
// NaNs if possible, bail out otherwise.

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313

1414
#include "LoopVectorizationPlanner.h"
1515
#include "VPlan.h"
16+
#include "VPlanAnalysis.h"
1617
#include "VPlanCFG.h"
1718
#include "VPlanDominatorTree.h"
1819
#include "VPlanPatternMatch.h"
1920
#include "VPlanTransforms.h"
21+
#include "VPlanUtils.h"
2022
#include "llvm/Analysis/LoopInfo.h"
2123
#include "llvm/Analysis/LoopIterator.h"
2224
#include "llvm/Analysis/ScalarEvolution.h"
25+
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
2326
#include "llvm/IR/InstrTypes.h"
2427
#include "llvm/IR/MDBuilder.h"
2528
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -997,7 +1000,48 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
9971000
return true;
9981001
}
9991002

1000-
bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
1003+
/// Try to convert FindLastIV to FindFirstIV reduction when using a strict
1004+
/// predicate. Returns the new FindFirstIVPhiR on success, nullptr on failure.
///
/// \p Plan is the VPlan used for type inference and for creating the sentinel
/// constant. \p FindLastIVPhiR is the reduction phi being replaced. \p IVOp is
/// the value whose SCEV range is checked. \p SE and \p L are forwarded to
/// vputils::getSCEVExprForVPValue.
///
/// The signed FindFirstIV range is tried first, then the unsigned one; if
/// neither check passes, nullptr is returned before any recipe is created.
/// On success the new phi is inserted before \p FindLastIVPhiR, all uses of
/// the old phi are redirected to it, and operand 2 of the old phi's
/// ComputeFindIVResult user is set to the new sentinel. The old phi is not
/// erased by this function.
1005+
static VPReductionPHIRecipe *
1006+
tryConvertToFindFirstIV(VPlan &Plan, VPReductionPHIRecipe *FindLastIVPhiR,
1007+
VPValue *IVOp, ScalarEvolution &SE, const Loop *L) {
1008+
Type *Ty = VPTypeAnalysis(Plan).inferScalarType(FindLastIVPhiR);
1009+
// Width of the phi's scalar type; used below to build the sentinel constant.
// NOTE(review): the range check uses the bit width of AR's type instead --
// confirm the two widths agree for every IVOp passed in.
unsigned NumBits = Ty->getIntegerBitWidth();
1010+
1011+
// Determine the reduction kind and sentinel based on the IV range.
1012+
RecurKind NewKind;
1013+
VPValue *NewSentinel;
1014+
// NOTE(review): cast<> assumes the SCEV computed for IVOp is always an
// add-recurrence -- confirm getSCEVExprForVPValue cannot yield another kind
// of SCEV here.
auto *AR = cast<SCEVAddRecExpr>(vputils::getSCEVExprForVPValue(IVOp, SE, L));
1015+
// Prefer the signed variant; its sentinel is the signed maximum.
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
1016+
AR, /*IsSigned=*/true, /*IsFindFirstIV=*/true, SE)) {
1017+
NewKind = RecurKind::FindFirstIVSMin;
1018+
NewSentinel = Plan.getConstantInt(APInt::getSignedMaxValue(NumBits));
1019+
// Otherwise fall back to the unsigned variant with the unsigned maximum.
} else if (RecurrenceDescriptor::isValidIVRangeForFindIV(
1020+
AR, /*IsSigned=*/false, /*IsFindFirstIV=*/true, SE)) {
1021+
NewKind = RecurKind::FindFirstIVUMin;
1022+
NewSentinel = Plan.getConstantInt(APInt::getMaxValue(NumBits));
1023+
} else {
1024+
// Neither range check passed: give up without modifying the plan.
return nullptr;
1025+
}
1026+
1027+
// Create the new FindFirstIV reduction recipe.
1028+
assert(!FindLastIVPhiR->isInLoop() && !FindLastIVPhiR->isOrdered());
1029+
// Carry over the old phi's VF scale factor into an unordered style.
ReductionStyle Style = RdxUnordered{FindLastIVPhiR->getVFScaleFactor()};
1030+
// The new phi starts from the sentinel and keeps the old phi's
// uses-outside-reduction-chain flag.
auto *FindFirstIVPhiR =
1031+
new VPReductionPHIRecipe(nullptr, NewKind, *NewSentinel, Style,
1032+
FindLastIVPhiR->hasUsesOutsideReductionChain());
1033+
// Reuse the old phi's backedge value as the new phi's second operand.
FindFirstIVPhiR->addOperand(FindLastIVPhiR->getBackedgeValue());
1034+
1035+
FindFirstIVPhiR->insertBefore(FindLastIVPhiR);
1036+
// Look up the ComputeFindIVResult user through the old phi before its uses
// are rewired below.
VPInstruction *FindLastIVResult =
1037+
findUserOf<VPInstruction::ComputeFindIVResult>(FindLastIVPhiR);
1038+
FindLastIVPhiR->replaceAllUsesWith(FindFirstIVPhiR);
1039+
// Operand 2 of the result instruction is replaced with the new sentinel.
FindLastIVResult->setOperand(2, NewSentinel);
1040+
return FindFirstIVPhiR;
1041+
}
1042+
1043+
bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
1044+
const Loop *L) {
10011045
for (auto &PhiR : make_early_inc_range(
10021046
Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis())) {
10031047
auto *MinMaxPhiR = dyn_cast<VPReductionPHIRecipe>(&PhiR);
@@ -1080,33 +1124,41 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
10801124
FindIVPhiR->getRecurrenceKind()))
10811125
return false;
10821126

1127+
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
1128+
"cannot handle inloop/ordered reductions yet");
1129+
10831130
// TODO: Support cases where IVOp is the IV increment.
10841131
if (!match(IVOp, m_TruncOrSelf(m_VPValue(IVOp))) ||
10851132
!isa<VPWidenIntOrFpInductionRecipe>(IVOp))
10861133
return false;
10871134

1088-
CmpInst::Predicate RdxPredicate = [RdxKind]() {
1135+
// Check if the predicate is compatible with the reduction kind.
1136+
bool IsValidPredicate = [RdxKind, Pred]() {
10891137
switch (RdxKind) {
10901138
case RecurKind::UMin:
1091-
return CmpInst::ICMP_UGE;
1139+
return Pred == CmpInst::ICMP_UGE || Pred == CmpInst::ICMP_UGT;
10921140
case RecurKind::UMax:
1093-
return CmpInst::ICMP_ULE;
1141+
return Pred == CmpInst::ICMP_ULE || Pred == CmpInst::ICMP_ULT;
10941142
case RecurKind::SMax:
1095-
return CmpInst::ICMP_SLE;
1143+
return Pred == CmpInst::ICMP_SLE || Pred == CmpInst::ICMP_SLT;
10961144
case RecurKind::SMin:
1097-
return CmpInst::ICMP_SGE;
1145+
return Pred == CmpInst::ICMP_SGE || Pred == CmpInst::ICMP_SGT;
10981146
default:
10991147
llvm_unreachable("unhandled recurrence kind");
11001148
}
11011149
}();
11021150

1103-
// TODO: Strict predicates need to find the first IV value for which the
1104-
// predicate holds, not the last.
1105-
if (Pred != RdxPredicate)
1151+
if (!IsValidPredicate)
11061152
return false;
11071153

1108-
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
1109-
"cannot handle inloop/ordered reductions yet");
1154+
// For strict predicates, try to convert the FindLastIV reduction to a
1155+
// FindFirstIV reduction.
1156+
bool IsStrictPredicate = ICmpInst::isLT(Pred) || ICmpInst::isGT(Pred);
1157+
if (IsStrictPredicate) {
1158+
FindIVPhiR = tryConvertToFindFirstIV(Plan, FindIVPhiR, IVOp, SE, L);
1159+
if (!FindIVPhiR)
1160+
return false;
1161+
}
11101162

11111163
// The reduction using MinMaxPhiR needs adjusting to compute the correct
11121164
// result:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
163163
return cast<VPExpressionRecipe>(this)->mayHaveSideEffects();
164164
case VPDerivedIVSC:
165165
case VPFirstOrderRecurrencePHISC:
166+
case VPReductionPHISC:
166167
case VPPredInstPHISC:
167168
case VPVectorEndPointerSC:
168169
return false;

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,10 @@ struct VPlanTransforms {
146146
const TargetLibraryInfo &TLI);
147147

148148
/// Try to legalize reductions with multiple in-loop uses. Currently only
149-
/// min/max reductions used by FindLastIV reductions are supported. Otherwise
150-
/// return false.
151-
static bool handleMultiUseReductions(VPlan &Plan);
149+
/// min/max reductions used by FindLastIV and FindFirstIV reductions are
150+
/// supported. Otherwise return false.
151+
static bool handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
152+
const Loop *L);
152153

153154
/// Try to have all users of fixed-order recurrences appear after the recipe
154155
/// defining their previous value, by either sinking users or hoisting recipes

0 commit comments

Comments
 (0)