Skip to content

Commit 02d19db

Browse files
committed
[LV] Vectorize select min/max index.
Add support for vectorizing loops that select the index of the minimum or maximum element. The patch implements vectorizing those patterns by combining Min/Max and FindFirstIV reductions. It extends matching Min/Max reductions to allow in-loop users that are FindLastIV reductions. It records a flag indicating that the Min/Max reduction is used by another reduction. When creating reduction recipes, we process any reduction that has other reduction users. The reduction using the min/max reduction needs adjusting to compute the correct result: 1. We need to find the first IV for which the condition based on the min/max reduction is true, 2. Compare the partial min/max reduction result to its final value and, 3. Select the lanes of the partial FindLastIV reductions which correspond to the lanes matching the min/max reduction result.
1 parent 1a3e857 commit 02d19db

File tree

11 files changed

+1433
-241
lines changed

11 files changed

+1433
-241
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
798798
// For each block in the loop.
799799
for (BasicBlock *BB : TheLoop->blocks()) {
800800
// Scan the instructions in the block and look for hazards.
801+
PHINode *UnclassifiedPhi = nullptr;
801802
for (Instruction &I : *BB) {
802803
if (auto *Phi = dyn_cast<PHINode>(&I)) {
803804
Type *PhiTy = Phi->getType();
@@ -887,12 +888,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
887888
addInductionPhi(Phi, ID, AllowedExit);
888889
continue;
889890
}
890-
891-
reportVectorizationFailure("Found an unidentified PHI",
892-
"value that could not be identified as "
893-
"reduction is used outside the loop",
894-
"NonReductionValueUsedOutsideLoop", ORE, TheLoop, Phi);
895-
return false;
891+
UnclassifiedPhi = Phi;
896892
} // end of PHI handling
897893

898894
// We handle calls that:
@@ -1043,6 +1039,19 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
10431039
return false;
10441040
}
10451041
} // next instr.
1042+
if (UnclassifiedPhi && none_of(BB->phis(), [this](PHINode &P) {
1043+
auto I = Reductions.find(&P);
1044+
return I != Reductions.end() &&
1045+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(
1046+
I->second.getRecurrenceKind());
1047+
})) {
1048+
reportVectorizationFailure("Found an unidentified PHI",
1049+
"value that could not be identified as "
1050+
"reduction is used outside the loop",
1051+
"NonReductionValueUsedOutsideLoop", ORE,
1052+
TheLoop, UnclassifiedPhi);
1053+
return false;
1054+
}
10461055
}
10471056

10481057
if (!PrimaryInduction) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7227,6 +7227,9 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72277227
Value *StartV = getStartValueFromReductionResult(EpiRedResult);
72287228
Value *SentinelV = EpiRedResult->getOperand(2)->getLiveInIRValue();
72297229
using namespace llvm::PatternMatch;
7230+
MainResumeValue = cast<VPInstruction>(EpiRedHeaderPhi->getStartValue())
7231+
->getOperand(0)
7232+
->getUnderlyingValue();
72307233
Value *Cmp, *OrigResumeV, *CmpOp;
72317234
[[maybe_unused]] bool IsExpectedPattern =
72327235
match(MainResumeValue,
@@ -7237,7 +7240,11 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72377240
((CmpOp == StartV && isGuaranteedNotToBeUndefOrPoison(CmpOp))));
72387241
assert(IsExpectedPattern && "Unexpected reduction resume pattern");
72397242
MainResumeValue = OrigResumeV;
7243+
} else {
7244+
if (auto *VPI = dyn_cast<VPInstruction>(EpiRedHeaderPhi->getStartValue()))
7245+
MainResumeValue = VPI->getOperand(0)->getUnderlyingValue();
72407246
}
7247+
72417248
PHINode *MainResumePhi = cast<PHINode>(MainResumeValue);
72427249

72437250
// When fixing reductions in the epilogue loop we should already have
@@ -8251,9 +8258,6 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
82518258
return Recipe;
82528259

82538260
VPHeaderPHIRecipe *PhiRecipe = nullptr;
8254-
assert((Legal->isReductionVariable(Phi) ||
8255-
Legal->isFixedOrderRecurrence(Phi)) &&
8256-
"can only widen reductions and fixed-order recurrences here");
82578261
VPValue *StartV = Operands[0];
82588262
if (Legal->isReductionVariable(Phi)) {
82598263
const RecurrenceDescriptor &RdxDesc = Legal->getRecurrenceDescriptor(Phi);
@@ -8266,12 +8270,17 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
82668270
PhiRecipe = new VPReductionPHIRecipe(
82678271
Phi, RdxDesc.getRecurrenceKind(), *StartV, CM.isInLoopReduction(Phi),
82688272
CM.useOrderedReductions(RdxDesc), ScaleFactor);
8269-
} else {
8273+
} else if (Legal->isFixedOrderRecurrence(Phi)) {
82708274
// TODO: Currently fixed-order recurrences are modeled as chains of
82718275
// first-order recurrences. If there are no users of the intermediate
82728276
// recurrences in the chain, the fixed order recurrence should be modeled
82738277
// directly, enabling more efficient codegen.
82748278
PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV);
8279+
} else {
8280+
// Failed to identify phi as reduction or fixed-order recurrence. Keep the
8281+
// original VPWidenPHIRecipe for now, to be legalized later if possible.
8282+
setRecipe(Phi, R);
8283+
return nullptr;
82758284
}
82768285
// Add backedge value.
82778286
PhiRecipe->addOperand(Operands[1]);
@@ -8456,7 +8465,7 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
84568465
// TODO: Extract final value from induction recipe initially, optimize to
84578466
// pre-computed end value together in optimizeInductionExitUsers.
84588467
auto *VectorPhiR =
8459-
cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi()));
8468+
cast<VPSingleDefRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi()));
84608469
if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) {
84618470
if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction(
84628471
WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo,
@@ -8478,7 +8487,7 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
84788487
// which for FORs is a vector whose last element needs to be extracted. The
84798488
// start value provides the value if the loop is bypassed.
84808489
bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR);
8481-
auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
8490+
auto *ResumeFromVectorLoop = VectorPhiR->getOperand(1);
84828491
assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
84838492
"Cannot handle loops with uncountable early exits");
84848493
if (IsFOR)
@@ -8487,7 +8496,7 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
84878496
"vector.recur.extract");
84888497
StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx";
84898498
auto *ResumePhiR = ScalarPHBuilder.createScalarPhi(
8490-
{ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name);
8499+
{ResumeFromVectorLoop, VectorPhiR->getOperand(0)}, {}, Name);
84918500
ScalarPhiIRI->addOperand(ResumePhiR);
84928501
}
84938502
}
@@ -8758,6 +8767,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
87588767
VPRecipeBase *Recipe =
87598768
RecipeBuilder.tryToCreateWidenRecipe(SingleDef, Range);
87608769
if (!Recipe) {
8770+
if (isa<VPWidenPHIRecipe>(SingleDef))
8771+
continue;
87618772
SmallVector<VPValue *, 4> Operands(R.operands());
87628773
Recipe = RecipeBuilder.handleReplication(Instr, Operands, Range);
87638774
}
@@ -8820,6 +8831,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
88208831
// Adjust the recipes for any inloop reductions.
88218832
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
88228833

8834+
// Try to convert remaining VPWidenPHIRecipes to reduction recipes.
8835+
if (!VPlanTransforms::runPass(VPlanTransforms::legalizeUnclassifiedPhis,
8836+
*Plan))
8837+
return nullptr;
88238838
// Apply mandatory transformation to handle FP maxnum/minnum reduction with
88248839
// NaNs if possible, bail out otherwise.
88258840
if (!VPlanTransforms::runPass(VPlanTransforms::handleMaxMinNumReductions,
@@ -9292,6 +9307,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92929307
PhiR->setOperand(0, StartV);
92939308
}
92949309
}
9310+
92959311
for (VPRecipeBase *R : ToDelete)
92969312
R->eraseFromParent();
92979313

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,7 +1896,8 @@ class LLVM_ABI_FOR_TEST VPHeaderPHIRecipe : public VPSingleDefRecipe,
18961896
~VPHeaderPHIRecipe() override = default;
18971897

18981898
/// Method to support type inquiry through isa, cast, and dyn_cast.
1899-
static inline bool classof(const VPRecipeBase *B) {
1899+
static inline bool classof(const VPUser *U) {
1900+
auto *B = cast<VPRecipeBase>(U);
19001901
return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC &&
19011902
B->getVPDefID() <= VPDef::VPLastHeaderPHISC;
19021903
}
@@ -1905,6 +1906,10 @@ class LLVM_ABI_FOR_TEST VPHeaderPHIRecipe : public VPSingleDefRecipe,
19051906
return B && B->getVPDefID() >= VPRecipeBase::VPFirstHeaderPHISC &&
19061907
B->getVPDefID() <= VPRecipeBase::VPLastHeaderPHISC;
19071908
}
1909+
static inline bool classof(const VPSingleDefRecipe *B) {
1910+
return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC &&
1911+
B->getVPDefID() <= VPDef::VPLastHeaderPHISC;
1912+
}
19081913

19091914
/// Generate the phi nodes.
19101915
void execute(VPTransformState &State) override = 0;
@@ -1966,7 +1971,7 @@ class VPWidenInductionRecipe : public VPHeaderPHIRecipe {
19661971
return R && classof(R);
19671972
}
19681973

1969-
static inline bool classof(const VPHeaderPHIRecipe *R) {
1974+
static inline bool classof(const VPSingleDefRecipe *R) {
19701975
return classof(static_cast<const VPRecipeBase *>(R));
19711976
}
19721977

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,3 +815,148 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
815815
MiddleTerm->setOperand(0, NewCond);
816816
return true;
817817
}
818+
819+
bool VPlanTransforms::legalizeUnclassifiedPhis(VPlan &Plan) {
820+
using namespace VPlanPatternMatch;
821+
for (auto &PhiR : make_early_inc_range(
822+
Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis())) {
823+
if (!isa<VPWidenPHIRecipe>(&PhiR))
824+
continue;
825+
826+
// Check if PhiR is a min/max reduction that has a user inside the loop
827+
// outside the min/max reduction chain. The other user must be the compare
828+
// of a FindLastIV reduction chain.
829+
auto *MinMaxPhiR = cast<VPWidenPHIRecipe>(&PhiR);
830+
auto *MinMaxOp = dyn_cast_or_null<VPSingleDefRecipe>(
831+
MinMaxPhiR->getOperand(1)->getDefiningRecipe());
832+
if (!MinMaxOp)
833+
return false;
834+
835+
// The incoming value must be a min/max instrinsic.
836+
// TODO: Also handle the select variant.
837+
Intrinsic::ID ID = Intrinsic::not_intrinsic;
838+
if (auto *WideInt = dyn_cast<VPWidenIntrinsicRecipe>(MinMaxOp))
839+
ID = WideInt->getVectorIntrinsicID();
840+
else {
841+
auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
842+
if (!RepR || !isa<IntrinsicInst>(RepR->getUnderlyingInstr()))
843+
return false;
844+
ID = cast<IntrinsicInst>(RepR->getUnderlyingInstr())->getIntrinsicID();
845+
}
846+
RecurKind RdxKind = RecurKind::None;
847+
switch (ID) {
848+
case Intrinsic::umax:
849+
RdxKind = RecurKind::UMax;
850+
break;
851+
case Intrinsic::umin:
852+
RdxKind = RecurKind::UMin;
853+
break;
854+
case Intrinsic::smax:
855+
RdxKind = RecurKind::SMax;
856+
break;
857+
case Intrinsic::smin:
858+
RdxKind = RecurKind::SMin;
859+
break;
860+
default:
861+
return false;
862+
}
863+
864+
// The min/max intrinsic must use the phi and itself must only be used by
865+
// the phi and a resume-phi in the scalar preheader.
866+
if (MinMaxOp->getOperand(0) != MinMaxPhiR &&
867+
MinMaxOp->getOperand(1) != MinMaxPhiR)
868+
return false;
869+
if (MinMaxPhiR->getNumUsers() != 2 ||
870+
any_of(MinMaxOp->users(), [MinMaxPhiR, &Plan](VPUser *U) {
871+
auto *Phi = dyn_cast<VPPhi>(U);
872+
return MinMaxPhiR != U &&
873+
(!Phi || Phi->getParent() != Plan.getScalarPreheader());
874+
}))
875+
return false;
876+
877+
// One user of MinMaxPhiR is MinMaxOp, the other users must be a compare
878+
// that's part of a FindLastIV chain.
879+
auto MinMaxUsers = to_vector(MinMaxPhiR->users());
880+
auto *Cmp = dyn_cast<VPRecipeWithIRFlags>(
881+
MinMaxUsers[0] == MinMaxOp ? MinMaxUsers[1] : MinMaxUsers[0]);
882+
VPValue *CmpOpA;
883+
VPValue *CmpOpB;
884+
if (!Cmp || Cmp->getNumUsers() != 1 ||
885+
!match(Cmp, m_Binary<Instruction::ICmp>(m_VPValue(CmpOpA),
886+
m_VPValue(CmpOpB))))
887+
return false;
888+
889+
// Normalize the predicate so MinMaxPhiR is on the right side.
890+
CmpInst::Predicate Pred = Cmp->getPredicate();
891+
if (CmpOpA == MinMaxPhiR)
892+
Pred = CmpInst::getSwappedPredicate(Pred);
893+
894+
// Determine if the predicate is not strict.
895+
bool IsNonStrictPred = ICmpInst::isLE(Pred) || ICmpInst::isGE(Pred);
896+
// Account for a mis-match between RdxKind and the predicate.
897+
switch (RdxKind) {
898+
case RecurKind::UMin:
899+
case RecurKind::SMin:
900+
IsNonStrictPred |= ICmpInst::isGT(Pred);
901+
break;
902+
case RecurKind::UMax:
903+
case RecurKind::SMax:
904+
IsNonStrictPred |= ICmpInst::isLT(Pred);
905+
break;
906+
default:
907+
llvm_unreachable("unsupported kind");
908+
}
909+
910+
// TODO: Strict predicates need to find the first IV value for which the
911+
// predicate holds, not the last.
912+
if (Pred == CmpInst::ICMP_NE || !IsNonStrictPred)
913+
return false;
914+
915+
// Cmp must be used by the select of a FindLastIV chain.
916+
VPValue *Sel = dyn_cast<VPSingleDefRecipe>(*Cmp->user_begin());
917+
VPValue *IVOp, *FindIV;
918+
if (!Sel ||
919+
!match(Sel,
920+
m_Select(m_Specific(Cmp), m_VPValue(IVOp), m_VPValue(FindIV))) ||
921+
Sel->getNumUsers() != 2 || !isa<VPWidenIntOrFpInductionRecipe>(IVOp))
922+
return false;
923+
auto *FindIVPhiR = dyn_cast<VPReductionPHIRecipe>(FindIV);
924+
if (!FindIVPhiR || !RecurrenceDescriptor::isFindLastIVRecurrenceKind(
925+
FindIVPhiR->getRecurrenceKind()))
926+
return false;
927+
928+
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
929+
"cannot handle inloop/ordered reductions yet");
930+
931+
auto NewPhiR = new VPReductionPHIRecipe(
932+
cast<PHINode>(MinMaxPhiR->getUnderlyingInstr()), RdxKind,
933+
*MinMaxPhiR->getOperand(0), false, false, 1);
934+
NewPhiR->insertBefore(MinMaxPhiR);
935+
MinMaxPhiR->replaceAllUsesWith(NewPhiR);
936+
NewPhiR->addOperand(MinMaxPhiR->getOperand(1));
937+
MinMaxPhiR->eraseFromParent();
938+
939+
// The reduction using MinMaxPhiR needs adjusting to compute the correct
940+
// result:
941+
// 1. We need to find the last IV for which the condition based on the
942+
// min/max recurrence is true,
943+
// 2. Compare the partial min/max reduction result to its final value and,
944+
// 3. Select the lanes of the partial FindLastIV reductions which
945+
// correspond to the lanes matching the min/max reduction result.
946+
VPInstruction *FindIVResult = cast<VPInstruction>(
947+
*(Sel->user_begin() + (*Sel->user_begin() == FindIVPhiR ? 1 : 0)));
948+
VPBuilder B(FindIVResult);
949+
VPInstruction *MinMaxResult =
950+
B.createNaryOp(VPInstruction::ComputeReductionResult,
951+
{NewPhiR, NewPhiR->getBackedgeValue()}, VPIRFlags(), {});
952+
NewPhiR->getBackedgeValue()->replaceUsesWithIf(
953+
MinMaxResult, [](VPUser &U, unsigned) { return isa<VPPhi>(&U); });
954+
auto *FinalMinMaxCmp = B.createICmp(
955+
CmpInst::ICMP_EQ, MinMaxResult->getOperand(1), MinMaxResult);
956+
auto *FinalIVSelect =
957+
B.createSelect(FinalMinMaxCmp, FindIVResult->getOperand(3),
958+
FindIVResult->getOperand(2));
959+
FindIVResult->setOperand(3, FinalIVSelect);
960+
}
961+
return true;
962+
}

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ struct VPlanTransforms {
9393
GetIntOrFpInductionDescriptor,
9494
ScalarEvolution &SE, const TargetLibraryInfo &TLI);
9595

96+
/// Try to legalize unclassified phis by converting VPWidenPHIRecipes to
97+
/// min-max reductions used by FindLastIV reductions if possible. Returns
98+
/// false if the VPlan contains VPWidenPHIRecipes that cannot be legalized.
99+
static bool legalizeUnclassifiedPhis(VPlan &Plan);
100+
96101
/// Try to have all users of fixed-order recurrences appear after the recipe
97102
/// defining their previous value, by either sinking users or hoisting recipes
98103
/// defining their previous value (and its operands). Then introduce

0 commit comments

Comments
 (0)