Skip to content

Commit 5d3f690

Browse files
committed
Reapply "[LV] Use ExtractLane(LastActiveLane, V) live outs when tail-folding. (llvm#149042)"
This reverts commit a6edeed.
1 parent 318e7df commit 5d3f690

22 files changed

+1673
-605
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,24 +2095,6 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
20952095
for (const auto &Reduction : getReductionVars())
20962096
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
20972097

2098-
// TODO: handle non-reduction outside users when tail is folded by masking.
2099-
for (auto *AE : AllowedExit) {
2100-
// Check that all users of allowed exit values are inside the loop or
2101-
// are the live-out of a reduction.
2102-
if (ReductionLiveOuts.count(AE))
2103-
continue;
2104-
for (User *U : AE->users()) {
2105-
Instruction *UI = cast<Instruction>(U);
2106-
if (TheLoop->contains(UI))
2107-
continue;
2108-
LLVM_DEBUG(
2109-
dbgs()
2110-
<< "LV: Cannot fold tail by masking, loop has an outside user for "
2111-
<< *UI << "\n");
2112-
return false;
2113-
}
2114-
}
2115-
21162098
for (const auto &Entry : getInductionVars()) {
21172099
PHINode *OrigPhi = Entry.first;
21182100
for (User *U : OrigPhi->users()) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8941,7 +8941,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89418941
if (FinalReductionResult == U || Parent->getParent())
89428942
continue;
89438943
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
8944-
if (match(U, m_ExtractLastElement(m_VPValue())))
8944+
if (match(U, m_CombineOr(m_ExtractLastElement(m_VPValue()),
8945+
m_ExtractLane(m_VPValue(), m_VPValue()))))
89458946
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
89468947
}
89478948

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,13 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10971097
// It produces the lane index across all unrolled iterations. Unrolling will
10981098
// add all copies of its original operand as additional operands.
10991099
FirstActiveLane,
1100+
// Calculates the last active lane index of the vector predicate operands.
1101+
// The predicates must be prefix-masks (all 1s before all 0s). Used when
1102+
// tail-folding to extract the correct live-out value from the last active
1103+
// iteration. It produces the lane index across all unrolled iterations.
1104+
// Unrolling will add all copies of its original operand as additional
1105+
// operands.
1106+
LastActiveLane,
11001107

11011108
// The opcodes below are used for VPInstructionWithType.
11021109
//

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
115115
case VPInstruction::ExtractLane:
116116
return inferScalarType(R->getOperand(1));
117117
case VPInstruction::FirstActiveLane:
118+
case VPInstruction::LastActiveLane:
118119
return Type::getIntNTy(Ctx, 64);
119120
case VPInstruction::ExtractLastElement:
120121
case VPInstruction::ExtractLastLanePerPart:

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,12 +398,24 @@ m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) {
398398
return m_VPInstruction<Instruction::ExtractElement>(Op0, Op1);
399399
}
400400

401+
template <typename Op0_t, typename Op1_t>
402+
inline VPInstruction_match<VPInstruction::ExtractLane, Op0_t, Op1_t>
403+
m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
404+
return m_VPInstruction<VPInstruction::ExtractLane>(Op0, Op1);
405+
}
406+
401407
template <typename Op0_t>
402408
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
403409
m_ExtractLastLanePerPart(const Op0_t &Op0) {
404410
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
405411
}
406412

413+
template <typename Op0_t>
414+
inline VPInstruction_match<VPInstruction::ExtractPenultimateElement, Op0_t>
415+
m_ExtractPenultimateElement(const Op0_t &Op0) {
416+
return m_VPInstruction<VPInstruction::ExtractPenultimateElement>(Op0);
417+
}
418+
407419
template <typename Op0_t, typename Op1_t, typename Op2_t>
408420
inline VPInstruction_match<VPInstruction::ActiveLaneMask, Op0_t, Op1_t, Op2_t>
409421
m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
@@ -432,6 +444,16 @@ m_FirstActiveLane(const Op0_t &Op0) {
432444
return m_VPInstruction<VPInstruction::FirstActiveLane>(Op0);
433445
}
434446

447+
template <typename Op0_t>
448+
inline VPInstruction_match<VPInstruction::LastActiveLane, Op0_t>
449+
m_LastActiveLane(const Op0_t &Op0) {
450+
return m_VPInstruction<VPInstruction::LastActiveLane>(Op0);
451+
}
452+
453+
inline VPInstruction_match<VPInstruction::StepVector> m_StepVector() {
454+
return m_VPInstruction<VPInstruction::StepVector>();
455+
}
456+
435457
template <unsigned Opcode, typename Op0_t>
436458
inline AllRecipe_match<Opcode, Op0_t> m_Unary(const Op0_t &Op0) {
437459
return AllRecipe_match<Opcode, Op0_t>(Op0);

llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ class VPPredicator {
4444
/// possibly inserting new recipes at \p Dst (using Builder's insertion point)
4545
VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);
4646

47-
/// Returns the *entry* mask for \p VPBB.
48-
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
49-
return BlockMaskCache.lookup(VPBB);
50-
}
51-
5247
/// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
5348
/// already have a mask.
5449
void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
@@ -68,6 +63,11 @@ class VPPredicator {
6863
}
6964

7065
public:
66+
/// Returns the *entry* mask for \p VPBB.
67+
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
68+
return BlockMaskCache.lookup(VPBB);
69+
}
70+
7171
/// Returns the precomputed predicate of the edge from \p Src to \p Dst.
7272
VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
7373
return EdgeMaskCache.lookup({Src, Dst});
@@ -301,5 +301,34 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
301301

302302
PrevVPBB = VPBB;
303303
}
304+
305+
// If we folded the tail and introduced a header mask, any extract of the
306+
// last element must be updated to extract from the last active lane of the
307+
// header mask instead (i.e., the lane corresponding to the last active
308+
// iteration).
309+
if (FoldTail) {
310+
assert(Plan.getExitBlocks().size() == 1 &&
311+
"only a single-exit block is supported currently");
312+
VPBasicBlock *EB = Plan.getExitBlocks().front();
313+
assert(EB->getSinglePredecessor() == Plan.getMiddleBlock() &&
314+
"the exit block must have middle block as single predecessor");
315+
316+
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
317+
for (auto &P : EB->phis()) {
318+
auto *ExitIRI = cast<VPIRPhi>(&P);
319+
VPValue *Inc = ExitIRI->getIncomingValue(0);
320+
VPValue *Op;
321+
if (!match(Inc, m_ExtractLastElement(m_VPValue(Op))))
322+
continue;
323+
324+
// Compute the index of the last active lane.
325+
VPValue *HeaderMask = Predicator.getBlockInMask(Header);
326+
VPValue *LastActiveLane =
327+
B.createNaryOp(VPInstruction::LastActiveLane, HeaderMask);
328+
auto *Ext =
329+
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
330+
Inc->replaceAllUsesWith(Ext);
331+
}
332+
}
304333
return Predicator.getBlockMaskCache();
305334
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,6 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
560560
case VPInstruction::ExtractLastElement:
561561
case VPInstruction::ExtractLastLanePerPart:
562562
case VPInstruction::ExtractPenultimateElement:
563-
case VPInstruction::FirstActiveLane:
564563
case VPInstruction::Not:
565564
case VPInstruction::ResumeForEpilogue:
566565
case VPInstruction::Unpack:
@@ -589,6 +588,8 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
589588
case Instruction::GetElementPtr:
590589
case Instruction::PHI:
591590
case Instruction::Switch:
591+
case VPInstruction::FirstActiveLane:
592+
case VPInstruction::LastActiveLane:
592593
case VPInstruction::SLPLoad:
593594
case VPInstruction::SLPStore:
594595
// Cannot determine the number of operands from the opcode.
@@ -1174,6 +1175,29 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11741175
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
11751176
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
11761177
}
1178+
case VPInstruction::LastActiveLane: {
1179+
Type *ScalarTy = Ctx.Types.inferScalarType(getOperand(0));
1180+
if (VF.isScalar())
1181+
return Ctx.TTI.getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
1182+
CmpInst::makeCmpResultType(ScalarTy),
1183+
CmpInst::ICMP_EQ, Ctx.CostKind);
1184+
// Calculate the cost of determining the lane index: NOT + cttz_elts + SUB.
1185+
auto *PredTy = toVectorTy(ScalarTy, VF);
1186+
IntrinsicCostAttributes Attrs(Intrinsic::experimental_cttz_elts,
1187+
Type::getInt64Ty(Ctx.LLVMCtx),
1188+
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
1189+
InstructionCost Cost = Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
1190+
// Add cost of NOT operation on the predicate.
1191+
Cost += Ctx.TTI.getArithmeticInstrCost(
1192+
Instruction::Xor, PredTy, Ctx.CostKind,
1193+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
1194+
{TargetTransformInfo::OK_UniformConstantValue,
1195+
TargetTransformInfo::OP_None});
1196+
// Add cost of SUB operation on the index.
1197+
Cost += Ctx.TTI.getArithmeticInstrCost(
1198+
Instruction::Sub, Type::getInt64Ty(Ctx.LLVMCtx), Ctx.CostKind);
1199+
return Cost;
1200+
}
11771201
case VPInstruction::FirstOrderRecurrenceSplice: {
11781202
assert(VF.isVector() && "Scalar FirstOrderRecurrenceSplice?");
11791203
SmallVector<int> Mask(VF.getKnownMinValue());
@@ -1228,6 +1252,7 @@ bool VPInstruction::isVectorToScalar() const {
12281252
getOpcode() == Instruction::ExtractElement ||
12291253
getOpcode() == VPInstruction::ExtractLane ||
12301254
getOpcode() == VPInstruction::FirstActiveLane ||
1255+
getOpcode() == VPInstruction::LastActiveLane ||
12311256
getOpcode() == VPInstruction::ComputeAnyOfResult ||
12321257
getOpcode() == VPInstruction::ComputeFindIVResult ||
12331258
getOpcode() == VPInstruction::ComputeReductionResult ||
@@ -1294,6 +1319,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
12941319
case VPInstruction::ActiveLaneMask:
12951320
case VPInstruction::ExplicitVectorLength:
12961321
case VPInstruction::FirstActiveLane:
1322+
case VPInstruction::LastActiveLane:
12971323
case VPInstruction::FirstOrderRecurrenceSplice:
12981324
case VPInstruction::LogicalAnd:
12991325
case VPInstruction::Not:
@@ -1470,6 +1496,9 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
14701496
case VPInstruction::FirstActiveLane:
14711497
O << "first-active-lane";
14721498
break;
1499+
case VPInstruction::LastActiveLane:
1500+
O << "last-active-lane";
1501+
break;
14731502
case VPInstruction::ReductionStartVector:
14741503
O << "reduction-start-vector";
14751504
break;

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -826,8 +826,8 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan,
826826
VPValue *Op,
827827
ScalarEvolution &SE) {
828828
VPValue *Incoming, *Mask;
829-
if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>(
830-
m_FirstActiveLane(m_VPValue(Mask)), m_VPValue(Incoming))))
829+
if (!match(Op, m_ExtractLane(m_FirstActiveLane(m_VPValue(Mask)),
830+
m_VPValue(Incoming))))
831831
return nullptr;
832832

833833
auto *WideIV = getOptimizableIVOf(Incoming, SE);
@@ -1295,8 +1295,7 @@ static void simplifyRecipe(VPSingleDefRecipe *Def, VPTypeAnalysis &TypeInfo) {
12951295
}
12961296

12971297
// Look through ExtractPenultimateElement (BuildVector ....).
1298-
if (match(Def, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
1299-
m_BuildVector()))) {
1298+
if (match(Def, m_ExtractPenultimateElement(m_BuildVector()))) {
13001299
auto *BuildVector = cast<VPInstruction>(Def->getOperand(0));
13011300
Def->replaceAllUsesWith(
13021301
BuildVector->getOperand(BuildVector->getNumOperands() - 2));
@@ -2106,6 +2105,32 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
21062105
// Set the first operand of RecurSplice to FOR again, after replacing
21072106
// all users.
21082107
RecurSplice->setOperand(0, FOR);
2108+
2109+
// Check for users extracting at the penultimate active lane of the FOR.
2110+
// If only a single lane is active in the current iteration, we need to
2111+
// select the last element from the previous iteration (from the FOR phi
2112+
// directly).
2113+
for (VPUser *U : RecurSplice->users()) {
2114+
if (!match(U, m_ExtractLane(m_LastActiveLane(m_VPValue()),
2115+
m_Specific(RecurSplice))))
2116+
continue;
2117+
2118+
VPBuilder B(cast<VPInstruction>(U));
2119+
VPValue *LastActiveLane = cast<VPInstruction>(U)->getOperand(0);
2120+
Type *I64Ty = Type::getInt64Ty(Plan.getContext());
2121+
VPValue *Zero = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 0));
2122+
VPValue *One = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 1));
2123+
VPValue *PenultimateIndex =
2124+
B.createNaryOp(Instruction::Sub, {LastActiveLane, One});
2125+
VPValue *PenultimateLastIter =
2126+
B.createNaryOp(VPInstruction::ExtractLane,
2127+
{PenultimateIndex, FOR->getBackedgeValue()});
2128+
VPValue *LastPrevIter =
2129+
B.createNaryOp(VPInstruction::ExtractLastElement, FOR);
2130+
VPValue *Cmp = B.createICmp(CmpInst::ICMP_EQ, LastActiveLane, Zero);
2131+
VPValue *Sel = B.createSelect(Cmp, LastPrevIter, PenultimateLastIter);
2132+
cast<VPInstruction>(U)->replaceAllUsesWith(Sel);
2133+
}
21092134
}
21102135
return true;
21112136
}
@@ -3492,6 +3517,34 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
34923517
ToRemove.push_back(Expr);
34933518
}
34943519

3520+
// Expand LastActiveLane into Not + FirstActiveLane + Sub.
3521+
auto *LastActiveL = dyn_cast<VPInstruction>(&R);
3522+
if (LastActiveL &&
3523+
LastActiveL->getOpcode() == VPInstruction::LastActiveLane) {
3524+
// Create Not(Mask) for all operands.
3525+
SmallVector<VPValue *, 2> NotMasks;
3526+
for (VPValue *Op : LastActiveL->operands()) {
3527+
VPValue *NotMask = Builder.createNot(Op, LastActiveL->getDebugLoc());
3528+
NotMasks.push_back(NotMask);
3529+
}
3530+
3531+
// Create FirstActiveLane on the inverted masks.
3532+
VPValue *FirstInactiveLane = Builder.createNaryOp(
3533+
VPInstruction::FirstActiveLane, NotMasks,
3534+
LastActiveL->getDebugLoc(), "first.inactive.lane");
3535+
3536+
// Subtract 1 to get the last active lane.
3537+
VPValue *One = Plan.getOrAddLiveIn(
3538+
ConstantInt::get(Type::getInt64Ty(Plan.getContext()), 1));
3539+
VPValue *LastLane = Builder.createNaryOp(
3540+
Instruction::Sub, {FirstInactiveLane, One},
3541+
LastActiveL->getDebugLoc(), "last.active.lane");
3542+
3543+
LastActiveL->replaceAllUsesWith(LastLane);
3544+
ToRemove.push_back(LastActiveL);
3545+
continue;
3546+
}
3547+
34953548
VPValue *VectorStep;
34963549
VPValue *ScalarStep;
34973550
if (!match(&R, m_VPInstruction<VPInstruction::WideIVStep>(

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
352352
VPValue *Op1;
353353
if (match(&R, m_VPInstruction<VPInstruction::AnyOf>(m_VPValue(Op1))) ||
354354
match(&R, m_FirstActiveLane(m_VPValue(Op1))) ||
355+
match(&R, m_LastActiveLane(m_VPValue(Op1))) ||
355356
match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
356357
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
357358
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
@@ -364,17 +365,21 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
364365
continue;
365366
}
366367
VPValue *Op0;
367-
if (match(&R, m_VPInstruction<VPInstruction::ExtractLane>(
368-
m_VPValue(Op0), m_VPValue(Op1)))) {
368+
if (match(&R, m_ExtractLane(m_VPValue(Op0), m_VPValue(Op1)))) {
369369
addUniformForAllParts(cast<VPInstruction>(&R));
370370
for (unsigned Part = 1; Part != UF; ++Part)
371371
R.addOperand(getValueForPart(Op1, Part));
372372
continue;
373373
}
374374
if (match(&R, m_ExtractLastElement(m_VPValue(Op0))) ||
375-
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
376-
m_VPValue(Op0)))) {
375+
match(&R, m_ExtractPenultimateElement(m_VPValue(Op0)))) {
377376
addUniformForAllParts(cast<VPSingleDefRecipe>(&R));
377+
if (isa<VPFirstOrderRecurrencePHIRecipe>(Op0)) {
378+
assert(match(&R, m_ExtractLastElement(m_VPValue())) &&
379+
"can only extract last element of FOR");
380+
continue;
381+
}
382+
378383
if (Plan.hasScalarVFOnly()) {
379384
auto *I = cast<VPInstruction>(&R);
380385
// Extracting from end with VF = 1 implies retrieving the last or

0 commit comments

Comments
 (0)