-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[LV] Vectorize selecting last IV of min/max element. #141431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -815,3 +815,148 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { | |||||
MiddleTerm->setOperand(0, NewCond); | ||||||
return true; | ||||||
} | ||||||
|
||||||
bool VPlanTransforms::legalizeUnclassifiedPhis(VPlan &Plan) { | ||||||
using namespace VPlanPatternMatch; | ||||||
for (auto &PhiR : make_early_inc_range( | ||||||
Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis())) { | ||||||
if (!isa<VPWidenPHIRecipe>(&PhiR)) | ||||||
continue; | ||||||
|
||||||
// Check if PhiR is a min/max reduction that has a user inside the loop | ||||||
// outside the min/max reduction chain. The other user must be the compare | ||||||
// of a FindLastIV reduction chain. | ||||||
auto *MinMaxPhiR = cast<VPWidenPHIRecipe>(&PhiR); | ||||||
auto *MinMaxOp = dyn_cast_or_null<VPSingleDefRecipe>( | ||||||
MinMaxPhiR->getOperand(1)->getDefiningRecipe()); | ||||||
if (!MinMaxOp) | ||||||
return false; | ||||||
|
||||||
// The incoming value must be a min/max instrinsic. | ||||||
// TODO: Also handle the select variant. | ||||||
Intrinsic::ID ID = Intrinsic::not_intrinsic; | ||||||
if (auto *WideInt = dyn_cast<VPWidenIntrinsicRecipe>(MinMaxOp)) { | ||||||
ID = WideInt->getVectorIntrinsicID(); | ||||||
} else { | ||||||
auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp); | ||||||
if (!RepR || !isa<IntrinsicInst>(RepR->getUnderlyingInstr())) | ||||||
return false; | ||||||
ID = cast<IntrinsicInst>(RepR->getUnderlyingInstr())->getIntrinsicID(); | ||||||
} | ||||||
RecurKind RdxKind = RecurKind::None; | ||||||
switch (ID) { | ||||||
case Intrinsic::umax: | ||||||
RdxKind = RecurKind::UMax; | ||||||
break; | ||||||
case Intrinsic::umin: | ||||||
RdxKind = RecurKind::UMin; | ||||||
break; | ||||||
case Intrinsic::smax: | ||||||
RdxKind = RecurKind::SMax; | ||||||
break; | ||||||
case Intrinsic::smin: | ||||||
RdxKind = RecurKind::SMin; | ||||||
break; | ||||||
default: | ||||||
return false; | ||||||
} | ||||||
|
||||||
// The min/max intrinsic must use the phi and itself must only be used by | ||||||
// the phi and a resume-phi in the scalar preheader. | ||||||
if (MinMaxOp->getOperand(0) != MinMaxPhiR && | ||||||
MinMaxOp->getOperand(1) != MinMaxPhiR) | ||||||
return false; | ||||||
if (MinMaxPhiR->getNumUsers() != 2 || | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move the |
||||||
any_of(MinMaxOp->users(), [MinMaxPhiR, &Plan](VPUser *U) { | ||||||
auto *Phi = dyn_cast<VPPhi>(U); | ||||||
return MinMaxPhiR != U && | ||||||
(!Phi || Phi->getParent() != Plan.getScalarPreheader()); | ||||||
})) | ||||||
return false; | ||||||
|
||||||
// One user of MinMaxPhiR is MinMaxOp, the other users must be a compare | ||||||
// that's part of a FindLastIV chain. | ||||||
auto MinMaxUsers = to_vector(MinMaxPhiR->users()); | ||||||
auto *Cmp = dyn_cast<VPRecipeWithIRFlags>( | ||||||
MinMaxUsers[0] == MinMaxOp ? MinMaxUsers[1] : MinMaxUsers[0]); | ||||||
Comment on lines
+879
to
+881
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should check the number of phi's users, like #141467 did:
|
||||||
VPValue *CmpOpA; | ||||||
VPValue *CmpOpB; | ||||||
Comment on lines
+882
to
+883
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. compress into one line, and choose the better name. |
||||||
if (!Cmp || Cmp->getNumUsers() != 1 || | ||||||
!match(Cmp, m_Binary<Instruction::ICmp>(m_VPValue(CmpOpA), | ||||||
m_VPValue(CmpOpB)))) | ||||||
return false; | ||||||
|
||||||
// Normalize the predicate so MinMaxPhiR is on the right side. | ||||||
CmpInst::Predicate Pred = Cmp->getPredicate(); | ||||||
if (CmpOpA == MinMaxPhiR) | ||||||
Pred = CmpInst::getSwappedPredicate(Pred); | ||||||
|
||||||
// Determine if the predicate is not strict. | ||||||
bool IsNonStrictPred = ICmpInst::isLE(Pred) || ICmpInst::isGE(Pred); | ||||||
// Account for a mis-match between RdxKind and the predicate. | ||||||
switch (RdxKind) { | ||||||
case RecurKind::UMin: | ||||||
case RecurKind::SMin: | ||||||
IsNonStrictPred |= ICmpInst::isGT(Pred); | ||||||
break; | ||||||
case RecurKind::UMax: | ||||||
case RecurKind::SMax: | ||||||
IsNonStrictPred |= ICmpInst::isLT(Pred); | ||||||
break; | ||||||
default: | ||||||
llvm_unreachable("unsupported kind"); | ||||||
} | ||||||
|
||||||
// TODO: Strict predicates need to find the first IV value for which the | ||||||
// predicate holds, not the last. | ||||||
if (Pred == CmpInst::ICMP_NE || !IsNonStrictPred) | ||||||
return false; | ||||||
|
||||||
// Cmp must be used by the select of a FindLastIV chain. | ||||||
VPValue *Sel = dyn_cast<VPSingleDefRecipe>(*Cmp->user_begin()); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
VPValue *IVOp, *FindIV; | ||||||
if (!Sel || | ||||||
!match(Sel, | ||||||
m_Select(m_Specific(Cmp), m_VPValue(IVOp), m_VPValue(FindIV))) || | ||||||
Sel->getNumUsers() != 2 || !isa<VPWidenIntOrFpInductionRecipe>(IVOp)) | ||||||
return false; | ||||||
auto *FindIVPhiR = dyn_cast<VPReductionPHIRecipe>(FindIV); | ||||||
if (!FindIVPhiR || !RecurrenceDescriptor::isFindLastIVRecurrenceKind( | ||||||
FindIVPhiR->getRecurrenceKind())) | ||||||
return false; | ||||||
|
||||||
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() && | ||||||
"cannot handle inloop/ordered reductions yet"); | ||||||
|
||||||
auto NewPhiR = new VPReductionPHIRecipe( | ||||||
cast<PHINode>(MinMaxPhiR->getUnderlyingInstr()), RdxKind, | ||||||
*MinMaxPhiR->getOperand(0), false, false, 1); | ||||||
NewPhiR->insertBefore(MinMaxPhiR); | ||||||
MinMaxPhiR->replaceAllUsesWith(NewPhiR); | ||||||
NewPhiR->addOperand(MinMaxPhiR->getOperand(1)); | ||||||
MinMaxPhiR->eraseFromParent(); | ||||||
|
||||||
// The reduction using MinMaxPhiR needs adjusting to compute the correct | ||||||
// result: | ||||||
// 1. We need to find the last IV for which the condition based on the | ||||||
// min/max recurrence is true, | ||||||
// 2. Compare the partial min/max reduction result to its final value and, | ||||||
// 3. Select the lanes of the partial FindLastIV reductions which | ||||||
// correspond to the lanes matching the min/max reduction result. | ||||||
VPInstruction *FindIVResult = cast<VPInstruction>( | ||||||
*(Sel->user_begin() + (*Sel->user_begin() == FindIVPhiR ? 1 : 0))); | ||||||
VPBuilder B(FindIVResult); | ||||||
VPInstruction *MinMaxResult = | ||||||
B.createNaryOp(VPInstruction::ComputeReductionResult, | ||||||
{NewPhiR, NewPhiR->getBackedgeValue()}, VPIRFlags(), {}); | ||||||
NewPhiR->getBackedgeValue()->replaceUsesWithIf( | ||||||
MinMaxResult, [](VPUser &U, unsigned) { return isa<VPPhi>(&U); }); | ||||||
auto *FinalMinMaxCmp = B.createICmp( | ||||||
CmpInst::ICMP_EQ, MinMaxResult->getOperand(1), MinMaxResult); | ||||||
auto *FinalIVSelect = | ||||||
B.createSelect(FinalMinMaxCmp, FindIVResult->getOperand(3), | ||||||
FindIVResult->getOperand(2)); | ||||||
FindIVResult->setOperand(3, FinalIVSelect); | ||||||
} | ||||||
return true; | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MinMaxPhiR->getIncomingValue(1)