Skip to content

Commit ede1a96

Browse files
authored
[LV] Vectorize early exit loops with multiple exits. (#174864)
Building on top of the recent changes to introduce BranchOnTwoConds, this patch adds support for vectorizing loops with multiple early exits, all dominating a countable latch. The early exits must form a dominance chain, so we can simply check which early exit has been taken in dominance order. Currently LoopVectorizationLegality ensures that all exits other than the latch must be uncountable. handleUncountableEarlyExits now collects those uncountable exits and processes each exit. In the vector region, we compute if any exit has been taken, by taking the OR of all early exit conditions (EarlyExitConds) and checking if there's any active lane. If the early exit is taken, we exit the loop and compute which early exit has been taken. The first taken early exit is the one where its exit condition is true in the first active lane of EarlyExitConds. We create a chain of dispatch blocks outside the loop to check this for the early exit blocks ordered by dominance. Depends on llvm/llvm-project#174016. PR: llvm/llvm-project#174864
1 parent 8c162b7 commit ede1a96

14 files changed

+987
-410
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ class VPBuilder {
233233
return createNaryOp(VPInstruction::LogicalAnd, {LHS, RHS}, DL, Name);
234234
}
235235

236+
VPInstruction *createLogicalOr(VPValue *LHS, VPValue *RHS,
237+
DebugLoc DL = DebugLoc::getUnknown(),
238+
const Twine &Name = "") {
239+
return createNaryOp(VPInstruction::LogicalOr, {LHS, RHS}, DL, Name);
240+
}
241+
236242
VPInstruction *createSelect(VPValue *Cond, VPValue *TrueVal,
237243
VPValue *FalseVal,
238244
DebugLoc DL = DebugLoc::getUnknown(),

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9509,15 +9509,6 @@ bool LoopVectorizePass::processLoop(Loop *L) {
95099509
"UncountableEarlyExitLoopsDisabled", ORE, L);
95109510
return false;
95119511
}
9512-
SmallVector<BasicBlock *, 8> ExitingBlocks;
9513-
L->getExitingBlocks(ExitingBlocks);
9514-
// TODO: Support multiple uncountable early exits.
9515-
if (ExitingBlocks.size() - LVL.getCountableExitingBlocks().size() > 1) {
9516-
reportVectorizationFailure("Auto-vectorization of loops with multiple "
9517-
"uncountable early exits is not yet supported",
9518-
"MultipleUncountableEarlyExits", ORE, L);
9519-
return false;
9520-
}
95219512
}
95229513

95239514
if (!LVL.getPotentiallyFaultingLoads().empty()) {

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,7 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
12111211
// during unrolling.
12121212
ExtractPenultimateElement,
12131213
LogicalAnd, // Non-poison propagating logical And.
1214+
LogicalOr, // Non-poison propagating logical Or.
12141215
// Add an offset in bytes (second operand) to a base pointer (first
12151216
// operand). Only generates scalar values (either for the first lane only or
12161217
// for all lanes, depending on its uses).
@@ -1520,6 +1521,9 @@ class VPPhiAccessors {
15201521
/// Returns the incoming block with index \p Idx.
15211522
const VPBasicBlock *getIncomingBlock(unsigned Idx) const;
15221523

1524+
/// Returns the incoming value for \p VPBB. \p VPBB must be an incoming block.
1525+
VPValue *getIncomingValueForBlock(const VPBasicBlock *VPBB) const;
1526+
15231527
/// Returns the number of incoming values, also number of incoming blocks.
15241528
virtual unsigned getNumIncoming() const {
15251529
return getAsRecipe()->getNumOperands();

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
124124
case VPInstruction::LastActiveLane:
125125
return Type::getIntNTy(Ctx, 64);
126126
case VPInstruction::LogicalAnd:
127+
case VPInstruction::LogicalOr:
127128
assert(inferScalarType(R->getOperand(0))->isIntegerTy(1) &&
128129
inferScalarType(R->getOperand(1))->isIntegerTy(1) &&
129-
"LogicalAnd operands should be bool");
130+
"LogicalAnd/Or operands should be bool");
130131
return IntegerType::get(Ctx, 1);
131132
case VPInstruction::BranchOnCond:
132133
case VPInstruction::BranchOnTwoConds:

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -903,33 +903,27 @@ void VPlanTransforms::handleEarlyExits(VPlan &Plan,
903903
auto *LatchVPBB = cast<VPBasicBlock>(MiddleVPBB->getSinglePredecessor());
904904
VPBlockBase *HeaderVPB = cast<VPBasicBlock>(LatchVPBB->getSuccessors()[1]);
905905

906-
// Disconnect all early exits from the loop leaving it with a single exit from
907-
// the latch. Early exits that are countable are left for a scalar epilog. The
908-
// condition of uncountable early exits (currently at most one is supported)
909-
// is fused into the latch exit, and used to branch from middle block to the
910-
// early exit destination.
911-
[[maybe_unused]] bool HandledUncountableEarlyExit = false;
906+
if (HasUncountableEarlyExit) {
907+
handleUncountableEarlyExits(Plan, cast<VPBasicBlock>(HeaderVPB), LatchVPBB,
908+
MiddleVPBB);
909+
return;
910+
}
911+
912+
// Disconnect countable early exits from the loop, leaving it with a single
913+
// exit from the latch. Countable early exits are left for a scalar epilog.
912914
for (VPIRBasicBlock *EB : Plan.getExitBlocks()) {
913915
for (VPBlockBase *Pred : to_vector(EB->getPredecessors())) {
914916
if (Pred == MiddleVPBB)
915917
continue;
916-
if (HasUncountableEarlyExit) {
917-
assert(!HandledUncountableEarlyExit &&
918-
"can handle exactly one uncountable early exit");
919-
handleUncountableEarlyExit(cast<VPBasicBlock>(Pred), EB, Plan,
920-
cast<VPBasicBlock>(HeaderVPB), LatchVPBB);
921-
HandledUncountableEarlyExit = true;
922-
} else {
923-
for (VPRecipeBase &R : EB->phis())
924-
cast<VPIRPhi>(&R)->removeIncomingValueFor(Pred);
925-
}
926-
cast<VPBasicBlock>(Pred)->getTerminator()->eraseFromParent();
918+
919+
// Remove phi operands for the early exiting block.
920+
for (VPRecipeBase &R : EB->phis())
921+
cast<VPIRPhi>(&R)->removeIncomingValueFor(Pred);
922+
auto *EarlyExitingVPBB = cast<VPBasicBlock>(Pred);
923+
EarlyExitingVPBB->getTerminator()->eraseFromParent();
927924
VPBlockUtils::disconnectBlocks(Pred, EB);
928925
}
929926
}
930-
931-
assert((!HasUncountableEarlyExit || HandledUncountableEarlyExit) &&
932-
"missed an uncountable exit that must be handled");
933927
}
934928

935929
void VPlanTransforms::addMiddleCheck(VPlan &Plan,

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,9 +833,10 @@ inline auto m_c_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
833833
}
834834

835835
template <typename Op0_t, typename Op1_t>
836-
inline AllRecipe_match<Instruction::Select, Op0_t, specific_intval<1>, Op1_t>
837-
m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
838-
return m_Select(Op0, m_True(), Op1);
836+
inline auto m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
837+
return m_CombineOr(
838+
m_c_VPInstruction<VPInstruction::LogicalOr, Op0_t, Op1_t>(Op0, Op1),
839+
m_Select(Op0, m_True(), Op1));
839840
}
840841

841842
template <typename Op0_t, typename Op1_t>

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ unsigned VPInstruction::getNumOperandsForOpcode() const {
468468
case VPInstruction::ExitingIVValue:
469469
case VPInstruction::FirstOrderRecurrenceSplice:
470470
case VPInstruction::LogicalAnd:
471+
case VPInstruction::LogicalOr:
471472
case VPInstruction::PtrAdd:
472473
case VPInstruction::WidePtrAdd:
473474
case VPInstruction::WideIVStep:
@@ -813,6 +814,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
813814
Value *B = State.get(getOperand(1));
814815
return Builder.CreateLogicalAnd(A, B, Name);
815816
}
817+
case VPInstruction::LogicalOr: {
818+
Value *A = State.get(getOperand(0));
819+
Value *B = State.get(getOperand(1));
820+
return Builder.CreateLogicalOr(A, B, Name);
821+
}
816822
case VPInstruction::PtrAdd: {
817823
assert((State.VF.isScalar() || vputils::onlyFirstLaneUsed(this)) &&
818824
"can only generate first lane for PtrAdd");
@@ -1339,6 +1345,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
13391345
case VPInstruction::ExtractLastActive:
13401346
case VPInstruction::FirstOrderRecurrenceSplice:
13411347
case VPInstruction::LogicalAnd:
1348+
case VPInstruction::LogicalOr:
13421349
case VPInstruction::Not:
13431350
case VPInstruction::PtrAdd:
13441351
case VPInstruction::WideIVStep:
@@ -1506,6 +1513,9 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
15061513
case VPInstruction::LogicalAnd:
15071514
O << "logical-and";
15081515
break;
1516+
case VPInstruction::LogicalOr:
1517+
O << "logical-or";
1518+
break;
15091519
case VPInstruction::PtrAdd:
15101520
O << "ptradd";
15111521
break;
@@ -1705,6 +1715,14 @@ void VPPhiAccessors::removeIncomingValueFor(VPBlockBase *IncomingBlock) const {
17051715
R->removeOperand(Position);
17061716
}
17071717

1718+
VPValue *
1719+
VPPhiAccessors::getIncomingValueForBlock(const VPBasicBlock *VPBB) const {
1720+
for (unsigned Idx = 0; Idx != getNumIncoming(); ++Idx)
1721+
if (getIncomingBlock(Idx) == VPBB)
1722+
return getIncomingValue(Idx);
1723+
llvm_unreachable("VPBB is not an incoming block");
1724+
}
1725+
17081726
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
17091727
void VPPhiAccessors::printPhiOperands(raw_ostream &O,
17101728
VPSlotTracker &SlotTracker) const {

0 commit comments

Comments
 (0)