Skip to content

Commit 99addbf

Browse files
authored
[LV] Vectorize selecting last IV of min/max element. (#141431)
Add support for vectorizing loops that select the index of the minimum or maximum element. The patch implements vectorizing those patterns by combining Min/Max and FindFirstIV reductions. It extends matching Min/Max reductions to allow in-loop users that are FindLastIV reductions. It records a flag indicating that the Min/Max reduction is used by another reduction. The extra user is then checked as part of the new `handleMultiUseReductions` VPlan transformation. It processes any reduction that has other reduction users. The reduction using the min/max reduction currently must be a FindLastIV reduction, which needs adjusting to compute the correct result: 1. We need to find the last IV for which the condition based on the min/max reduction is true, 2. Compare the partial min/max reduction result to its final value, and 3. Select the lanes of the partial FindLastIV reductions which correspond to the lanes matching the min/max reduction result. Depends on #140451 PR: #141431
1 parent e99d8ad commit 99addbf

File tree

15 files changed

+1283
-134
lines changed

15 files changed

+1283
-134
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,17 @@ class RecurrenceDescriptor {
9595
RecurKind K, FastMathFlags FMF, Instruction *ExactFP,
9696
Type *RT, bool Signed, bool Ordered,
9797
SmallPtrSetImpl<Instruction *> &CI,
98-
unsigned MinWidthCastToRecurTy)
98+
unsigned MinWidthCastToRecurTy,
99+
bool PhiHasUsesOutsideReductionChain = false)
99100
: IntermediateStore(Store), StartValue(Start), LoopExitInstr(Exit),
100101
Kind(K), FMF(FMF), ExactFPMathInst(ExactFP), RecurrenceType(RT),
101102
IsSigned(Signed), IsOrdered(Ordered),
103+
PhiHasUsesOutsideReductionChain(PhiHasUsesOutsideReductionChain),
102104
MinWidthCastToRecurrenceType(MinWidthCastToRecurTy) {
103105
CastInsts.insert_range(CI);
106+
assert(
107+
(!PhiHasUsesOutsideReductionChain || isMinMaxRecurrenceKind(K)) &&
108+
"Only min/max recurrences are allowed to have multiple uses currently");
104109
}
105110

106111
/// This POD struct holds information about a potential recurrence operation.
@@ -339,6 +344,13 @@ class RecurrenceDescriptor {
339344
/// Expose an ordered FP reduction to the instance users.
340345
bool isOrdered() const { return IsOrdered; }
341346

347+
/// Returns true if the reduction PHI has in-loop users outside the reduction
348+
/// chain. Only min/max recurrences may carry this flag (asserted when the
349+
/// descriptor is constructed); it is relevant for FindLastIV patterns.
350+
bool hasUsesOutsideReductionChain() const {
351+
return PhiHasUsesOutsideReductionChain;
352+
}
353+
342354
/// Attempts to find a chain of operations from Phi to LoopExitInst that can
343355
/// be treated as a set of reductions instructions for in-loop reductions.
344356
LLVM_ABI SmallVector<Instruction *, 4> getReductionOpChain(PHINode *Phi,
@@ -376,6 +388,10 @@ class RecurrenceDescriptor {
376388
// Currently only a non-reassociative FAdd can be considered in-order,
377389
// if it is also the only FAdd in the PHI's use chain.
378390
bool IsOrdered = false;
391+
// True if the reduction PHI has in-loop users outside the reduction chain.
392+
// This is relevant for min/max reductions that are part of a FindLastIV
393+
// pattern.
394+
bool PhiHasUsesOutsideReductionChain = false;
379395
// Instructions used for type-promoting the recurrence.
380396
SmallPtrSet<Instruction *, 8> CastInsts;
381397
// The minimum width used by the recurrence.

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,52 @@ static bool checkOrderedReduction(RecurKind Kind, Instruction *ExactFPMathInst,
216216
return true;
217217
}
218218

219+
/// Returns true if \p Phi is an integer min/max reduction of kind \p Kind whose
220+
/// phi is also used outside the reduction chain (common for argmin/argmax loops
221+
/// selecting the index of the min/max element); on success \p RedDes is set.
222+
static bool isMinMaxReductionPhiWithUsersOutsideReductionChain(
223+
PHINode *Phi, RecurKind Kind, Loop *TheLoop, RecurrenceDescriptor &RedDes) {
224+
BasicBlock *Latch = TheLoop->getLoopLatch();
225+
if (!Latch)
226+
return false;
227+
// Inspect the latch value: reject unless Kind is an integer min/max kind, Phi does not have exactly one use and Inc has exactly one use.
228+
assert(Phi->getNumIncomingValues() == 2 && "phi must have 2 incoming values");
229+
Value *Inc = Phi->getIncomingValueForBlock(Latch);
230+
if (Phi->hasOneUse() || !Inc->hasOneUse() ||
231+
!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind))
232+
return false;
233+
// Match Inc against the min/max pattern selected by Kind, capturing its two operands in A and B.
234+
Value *A, *B;
235+
bool IsMinMax = [&]() {
236+
switch (Kind) {
237+
case RecurKind::UMax:
238+
return match(Inc, m_UMax(m_Value(A), m_Value(B)));
239+
case RecurKind::UMin:
240+
return match(Inc, m_UMin(m_Value(A), m_Value(B)));
241+
case RecurKind::SMax:
242+
return match(Inc, m_SMax(m_Value(A), m_Value(B)));
243+
case RecurKind::SMin:
244+
return match(Inc, m_SMin(m_Value(A), m_Value(B)));
245+
default:
246+
llvm_unreachable("all min/max kinds must be handled");
247+
}
248+
}();
249+
if (!IsMinMax)
250+
return false;
251+
// Exactly one of the two min/max operands must be Phi itself.
252+
if (A == B || (A != Phi && B != Phi))
253+
return false;
254+
// Record the recurrence with no loop-exit instruction or intermediate store, flagged as having users outside the reduction chain.
255+
SmallPtrSet<Instruction *, 4> CastInsts;
256+
Value *RdxStart = Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader());
257+
RedDes =
258+
RecurrenceDescriptor(RdxStart, /*Exit=*/nullptr, /*Store=*/nullptr, Kind,
259+
FastMathFlags(), /*ExactFP=*/nullptr, Phi->getType(),
260+
/*Signed=*/false, /*Ordered=*/false, CastInsts,
261+
/*MinWidthCastToRecurTy=*/-1U, /*PhiHasUsesOutsideReductionChain=*/true);
262+
return true;
263+
}
264+
219265
bool RecurrenceDescriptor::AddReductionVar(
220266
PHINode *Phi, RecurKind Kind, Loop *TheLoop, FastMathFlags FuncFMF,
221267
RecurrenceDescriptor &RedDes, DemandedBits *DB, AssumptionCache *AC,
@@ -227,6 +273,11 @@ bool RecurrenceDescriptor::AddReductionVar(
227273
if (Phi->getParent() != TheLoop->getHeader())
228274
return false;
229275

276+
// Check for min/max reduction variables that feed other users in the loop.
277+
if (isMinMaxReductionPhiWithUsersOutsideReductionChain(Phi, Kind, TheLoop,
278+
RedDes))
279+
return true;
280+
230281
// Obtain the reduction start value from the value that comes from the loop
231282
// preheader.
232283
Value *RdxStart = Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader());

llvm/lib/Transforms/Utils/LoopUnroll.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,8 @@ llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L,
12541254
/*DemandedBits=*/nullptr,
12551255
/*AC=*/nullptr, /*DT=*/nullptr, SE))
12561256
return std::nullopt;
1257+
if (RdxDesc.hasUsesOutsideReductionChain())
1258+
return std::nullopt;
12571259
RecurKind RK = RdxDesc.getRecurrenceKind();
12581260
// Skip unsupported reductions.
12591261
// TODO: Handle additional reductions, including min-max reductions.

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,11 @@ bool LoopVectorizationLegality::canVectorizeInstr(Instruction &I) {
877877
Requirements->addExactFPMathInst(RedDes.getExactFPMathInst());
878878
AllowedExit.insert(RedDes.getLoopExitInstr());
879879
Reductions[Phi] = RedDes;
880+
assert((!RedDes.hasUsesOutsideReductionChain() ||
881+
RecurrenceDescriptor::isMinMaxRecurrenceKind(
882+
RedDes.getRecurrenceKind())) &&
883+
"Only min/max recurrences are allowed to have multiple uses "
884+
"currently");
880885
return true;
881886
}
882887

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6593,6 +6593,11 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
65936593
PHINode *Phi = Reduction.first;
65946594
const RecurrenceDescriptor &RdxDesc = Reduction.second;
65956595

6596+
// Multi-use reductions (e.g., used in FindLastIV patterns) are handled
6597+
// separately and should not be considered for in-loop reductions.
6598+
if (RdxDesc.hasUsesOutsideReductionChain())
6599+
continue;
6600+
65966601
// We don't collect reductions that are type promoted (yet).
65976602
if (RdxDesc.getRecurrenceType() != Phi->getType())
65986603
continue;
@@ -7998,9 +8003,10 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
79988003
MapVector<Instruction *,
79998004
SmallVector<std::pair<PartialReductionChain, unsigned>>>
80008005
ChainsByPhi;
8001-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8002-
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
8003-
ChainsByPhi[Phi]);
8006+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8007+
if (Instruction *RdxExitInstr = RdxDesc.getLoopExitInstr())
8008+
getScaledReductions(Phi, RdxExitInstr, Range, ChainsByPhi[Phi]);
8009+
}
80048010

80058011
// A partial reduction is invalid if any of its extends are used by
80068012
// something that isn't another partial reduction. This is because the
@@ -8221,7 +8227,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
82218227
PhiRecipe = new VPReductionPHIRecipe(
82228228
Phi, RdxDesc.getRecurrenceKind(), *StartV,
82238229
getReductionStyle(UseInLoopReduction, UseOrderedReductions,
8224-
ScaleFactor));
8230+
ScaleFactor),
8231+
RdxDesc.hasUsesOutsideReductionChain());
82258232
} else {
82268233
// TODO: Currently fixed-order recurrences are modeled as chains of
82278234
// first-order recurrences. If there are no users of the intermediate
@@ -8555,6 +8562,11 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
85558562
// Adjust the recipes for any inloop reductions.
85568563
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
85578564

8565+
// Apply mandatory transformation to handle reductions with multiple in-loop
8566+
// uses if possible, bail out otherwise.
8567+
if (!VPlanTransforms::runPass(VPlanTransforms::handleMultiUseReductions,
8568+
*Plan))
8569+
return nullptr;
85588570
// Apply mandatory transformation to handle FP maxnum/minnum reduction with
85598571
// NaNs if possible, bail out otherwise.
85608572
if (!VPlanTransforms::runPass(VPlanTransforms::handleMaxMinNumReductions,

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,9 @@ class LLVM_ABI_FOR_TEST VPHeaderPHIRecipe : public VPSingleDefRecipe,
20712071
static inline bool classof(const VPValue *V) {
20722072
return isa<VPHeaderPHIRecipe>(V->getDefiningRecipe());
20732073
}
2074+
/// Type check for a VPSingleDefRecipe: casts \p R to VPRecipeBase and tests it with isa<VPHeaderPHIRecipe>.
static inline bool classof(const VPSingleDefRecipe *R) {
2075+
return isa<VPHeaderPHIRecipe>(static_cast<const VPRecipeBase *>(R));
2076+
}
20742077

20752078
/// Generate the phi nodes.
20762079
void execute(VPTransformState &State) override = 0;
@@ -2136,7 +2139,7 @@ class VPWidenInductionRecipe : public VPHeaderPHIRecipe {
21362139
return R && classof(R);
21372140
}
21382141

2139-
static inline bool classof(const VPHeaderPHIRecipe *R) {
2142+
static inline bool classof(const VPSingleDefRecipe *R) {
21402143
return classof(static_cast<const VPRecipeBase *>(R));
21412144
}
21422145

@@ -2432,19 +2435,27 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
24322435

24332436
ReductionStyle Style;
24342437

2438+
/// The phi is part of a multi-use reduction (e.g., used in FindLastIV
2439+
/// patterns for argmin/argmax).
2440+
/// TODO: Also support cases where the phi itself has a single use, but its
2441+
/// compare has multiple uses.
2442+
bool HasUsesOutsideReductionChain;
2443+
24352444
public:
24362445
/// Create a new VPReductionPHIRecipe for the reduction \p Phi.
24372446
VPReductionPHIRecipe(PHINode *Phi, RecurKind Kind, VPValue &Start,
2438-
ReductionStyle Style)
2447+
ReductionStyle Style,
2448+
bool HasUsesOutsideReductionChain = false)
24392449
: VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start), Kind(Kind),
2440-
Style(Style) {}
2450+
Style(Style),
2451+
HasUsesOutsideReductionChain(HasUsesOutsideReductionChain) {}
24412452

24422453
~VPReductionPHIRecipe() override = default;
24432454

24442455
VPReductionPHIRecipe *clone() override {
24452456
auto *R = new VPReductionPHIRecipe(
24462457
dyn_cast_or_null<PHINode>(getUnderlyingValue()), getRecurrenceKind(),
2447-
*getOperand(0), Style);
2458+
*getOperand(0), Style, HasUsesOutsideReductionChain);
24482459
R->addOperand(getBackedgeValue());
24492460
return R;
24502461
}
@@ -2481,6 +2492,11 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
24812492
/// Returns true if the reduction outputs a vector with a scaled down VF.
24822493
bool isPartialReduction() const { return getVFScaleFactor() > 1; }
24832494

2495+
/// Returns true if the phi is part of a multi-use reduction, i.e. it was created with HasUsesOutsideReductionChain set.
2496+
bool hasUsesOutsideReductionChain() const {
2497+
return HasUsesOutsideReductionChain;
2498+
}
2499+
24842500
/// Returns true if the recipe only uses the first lane of operand \p Op.
24852501
bool usesFirstLaneOnly(const VPValue *Op) const override {
24862502
assert(is_contained(operands(), Op) &&

0 commit comments

Comments
 (0)