Skip to content

Commit f8eca64

Browse files
committed
Reapply "[LV] Use ExtractLane(LastActiveLane, V) live outs when tail-folding. (#149042)"
This reverts commit a6edeed. The following fixes have landed, addressing issues causing the original revert: * #169298 * #167897 * #168949 Original message: Building on top of #148817, introduce a new abstract LastActiveLane opcode that gets lowered to Not(Mask) → FirstActiveLane(NotMask) → Sub(result, 1). When folding the tail, update all extracts for uses outside the loop the extract the value of the last actice lane. See also #148603 PR: #149042
1 parent cabcb5a commit f8eca64

23 files changed

+1720
-614
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,24 +2095,6 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
20952095
for (const auto &Reduction : getReductionVars())
20962096
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
20972097

2098-
// TODO: handle non-reduction outside users when tail is folded by masking.
2099-
for (auto *AE : AllowedExit) {
2100-
// Check that all users of allowed exit values are inside the loop or
2101-
// are the live-out of a reduction.
2102-
if (ReductionLiveOuts.count(AE))
2103-
continue;
2104-
for (User *U : AE->users()) {
2105-
Instruction *UI = cast<Instruction>(U);
2106-
if (TheLoop->contains(UI))
2107-
continue;
2108-
LLVM_DEBUG(
2109-
dbgs()
2110-
<< "LV: Cannot fold tail by masking, loop has an outside user for "
2111-
<< *UI << "\n");
2112-
return false;
2113-
}
2114-
}
2115-
21162098
for (const auto &Entry : getInductionVars()) {
21172099
PHINode *OrigPhi = Entry.first;
21182100
for (User *U : OrigPhi->users()) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8940,7 +8940,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89408940
if (FinalReductionResult == U || Parent->getParent())
89418941
continue;
89428942
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
8943-
if (match(U, m_ExtractLastElement(m_VPValue())))
8943+
if (match(U, m_CombineOr(m_ExtractLastElement(m_VPValue()),
8944+
m_ExtractLane(m_VPValue(), m_VPValue()))))
89448945
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
89458946
}
89468947

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,13 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10991099
// Implemented with @llvm.experimental.cttz.elts, but returns the expected
11001100
// result even with operands that are all zeroes.
11011101
FirstActiveLane,
1102+
// Calculates the last active lane index of the vector predicate operands.
1103+
// The predicates must be prefix-masks (all 1s before all 0s). Used when
1104+
// tail-folding to extract the correct live-out value from the last active
1105+
// iteration. It produces the lane index across all unrolled iterations.
1106+
// Unrolling will add all copies of its original operand as additional
1107+
// operands.
1108+
LastActiveLane,
11021109

11031110
// The opcodes below are used for VPInstructionWithType.
11041111
//

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
115115
case VPInstruction::ExtractLane:
116116
return inferScalarType(R->getOperand(1));
117117
case VPInstruction::FirstActiveLane:
118+
case VPInstruction::LastActiveLane:
118119
return Type::getIntNTy(Ctx, 64);
119120
case VPInstruction::ExtractLastElement:
120121
case VPInstruction::ExtractLastLanePerPart:

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,12 +398,24 @@ m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) {
398398
return m_VPInstruction<Instruction::ExtractElement>(Op0, Op1);
399399
}
400400

401+
template <typename Op0_t, typename Op1_t>
402+
inline VPInstruction_match<VPInstruction::ExtractLane, Op0_t, Op1_t>
403+
m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
404+
return m_VPInstruction<VPInstruction::ExtractLane>(Op0, Op1);
405+
}
406+
401407
template <typename Op0_t>
402408
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
403409
m_ExtractLastLanePerPart(const Op0_t &Op0) {
404410
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
405411
}
406412

413+
template <typename Op0_t>
414+
inline VPInstruction_match<VPInstruction::ExtractPenultimateElement, Op0_t>
415+
m_ExtractPenultimateElement(const Op0_t &Op0) {
416+
return m_VPInstruction<VPInstruction::ExtractPenultimateElement>(Op0);
417+
}
418+
407419
template <typename Op0_t, typename Op1_t, typename Op2_t>
408420
inline VPInstruction_match<VPInstruction::ActiveLaneMask, Op0_t, Op1_t, Op2_t>
409421
m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
@@ -436,6 +448,16 @@ m_FirstActiveLane(const Op0_t &Op0) {
436448
return m_VPInstruction<VPInstruction::FirstActiveLane>(Op0);
437449
}
438450

451+
template <typename Op0_t>
452+
inline VPInstruction_match<VPInstruction::LastActiveLane, Op0_t>
453+
m_LastActiveLane(const Op0_t &Op0) {
454+
return m_VPInstruction<VPInstruction::LastActiveLane>(Op0);
455+
}
456+
457+
inline VPInstruction_match<VPInstruction::StepVector> m_StepVector() {
458+
return m_VPInstruction<VPInstruction::StepVector>();
459+
}
460+
439461
template <unsigned Opcode, typename Op0_t>
440462
inline AllRecipe_match<Opcode, Op0_t> m_Unary(const Op0_t &Op0) {
441463
return AllRecipe_match<Opcode, Op0_t>(Op0);

llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ class VPPredicator {
4444
/// possibly inserting new recipes at \p Dst (using Builder's insertion point)
4545
VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);
4646

47-
/// Returns the *entry* mask for \p VPBB.
48-
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
49-
return BlockMaskCache.lookup(VPBB);
50-
}
51-
5247
/// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
5348
/// already have a mask.
5449
void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
@@ -68,6 +63,11 @@ class VPPredicator {
6863
}
6964

7065
public:
66+
/// Returns the *entry* mask for \p VPBB.
67+
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
68+
return BlockMaskCache.lookup(VPBB);
69+
}
70+
7171
/// Returns the precomputed predicate of the edge from \p Src to \p Dst.
7272
VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
7373
return EdgeMaskCache.lookup({Src, Dst});
@@ -301,5 +301,34 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
301301

302302
PrevVPBB = VPBB;
303303
}
304+
305+
// If we folded the tail and introduced a header mask, any extract of the
306+
// last element must be updated to extract from the last active lane of the
307+
// header mask instead (i.e., the lane corresponding to the last active
308+
// iteration).
309+
if (FoldTail) {
310+
assert(Plan.getExitBlocks().size() == 1 &&
311+
"only a single-exit block is supported currently");
312+
VPBasicBlock *EB = Plan.getExitBlocks().front();
313+
assert(EB->getSinglePredecessor() == Plan.getMiddleBlock() &&
314+
"the exit block must have middle block as single predecessor");
315+
316+
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
317+
for (auto &P : EB->phis()) {
318+
auto *ExitIRI = cast<VPIRPhi>(&P);
319+
VPValue *Inc = ExitIRI->getIncomingValue(0);
320+
VPValue *Op;
321+
if (!match(Inc, m_ExtractLastElement(m_VPValue(Op))))
322+
continue;
323+
324+
// Compute the index of the last active lane.
325+
VPValue *HeaderMask = Predicator.getBlockInMask(Header);
326+
VPValue *LastActiveLane =
327+
B.createNaryOp(VPInstruction::LastActiveLane, HeaderMask);
328+
auto *Ext =
329+
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
330+
Inc->replaceAllUsesWith(Ext);
331+
}
332+
}
304333
return Predicator.getBlockMaskCache();
305334
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,6 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
440440
case VPInstruction::ExtractLastElement:
441441
case VPInstruction::ExtractLastLanePerPart:
442442
case VPInstruction::ExtractPenultimateElement:
443-
case VPInstruction::FirstActiveLane:
444443
case VPInstruction::Not:
445444
case VPInstruction::ResumeForEpilogue:
446445
case VPInstruction::Unpack:
@@ -470,6 +469,8 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
470469
case Instruction::PHI:
471470
case Instruction::Switch:
472471
case VPInstruction::AnyOf:
472+
case VPInstruction::FirstActiveLane:
473+
case VPInstruction::LastActiveLane:
473474
case VPInstruction::SLPLoad:
474475
case VPInstruction::SLPStore:
475476
// Cannot determine the number of operands from the opcode.
@@ -1055,6 +1056,29 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
10551056
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
10561057
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
10571058
}
1059+
case VPInstruction::LastActiveLane: {
1060+
Type *ScalarTy = Ctx.Types.inferScalarType(getOperand(0));
1061+
if (VF.isScalar())
1062+
return Ctx.TTI.getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
1063+
CmpInst::makeCmpResultType(ScalarTy),
1064+
CmpInst::ICMP_EQ, Ctx.CostKind);
1065+
// Calculate the cost of determining the lane index: NOT + cttz_elts + SUB.
1066+
auto *PredTy = toVectorTy(ScalarTy, VF);
1067+
IntrinsicCostAttributes Attrs(Intrinsic::experimental_cttz_elts,
1068+
Type::getInt64Ty(Ctx.LLVMCtx),
1069+
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
1070+
InstructionCost Cost = Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
1071+
// Add cost of NOT operation on the predicate.
1072+
Cost += Ctx.TTI.getArithmeticInstrCost(
1073+
Instruction::Xor, PredTy, Ctx.CostKind,
1074+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
1075+
{TargetTransformInfo::OK_UniformConstantValue,
1076+
TargetTransformInfo::OP_None});
1077+
// Add cost of SUB operation on the index.
1078+
Cost += Ctx.TTI.getArithmeticInstrCost(
1079+
Instruction::Sub, Type::getInt64Ty(Ctx.LLVMCtx), Ctx.CostKind);
1080+
return Cost;
1081+
}
10581082
case VPInstruction::FirstOrderRecurrenceSplice: {
10591083
assert(VF.isVector() && "Scalar FirstOrderRecurrenceSplice?");
10601084
SmallVector<int> Mask(VF.getKnownMinValue());
@@ -1109,6 +1133,7 @@ bool VPInstruction::isVectorToScalar() const {
11091133
getOpcode() == Instruction::ExtractElement ||
11101134
getOpcode() == VPInstruction::ExtractLane ||
11111135
getOpcode() == VPInstruction::FirstActiveLane ||
1136+
getOpcode() == VPInstruction::LastActiveLane ||
11121137
getOpcode() == VPInstruction::ComputeAnyOfResult ||
11131138
getOpcode() == VPInstruction::ComputeFindIVResult ||
11141139
getOpcode() == VPInstruction::ComputeReductionResult ||
@@ -1176,6 +1201,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
11761201
case VPInstruction::ActiveLaneMask:
11771202
case VPInstruction::ExplicitVectorLength:
11781203
case VPInstruction::FirstActiveLane:
1204+
case VPInstruction::LastActiveLane:
11791205
case VPInstruction::FirstOrderRecurrenceSplice:
11801206
case VPInstruction::LogicalAnd:
11811207
case VPInstruction::Not:
@@ -1352,6 +1378,9 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
13521378
case VPInstruction::FirstActiveLane:
13531379
O << "first-active-lane";
13541380
break;
1381+
case VPInstruction::LastActiveLane:
1382+
O << "last-active-lane";
1383+
break;
13551384
case VPInstruction::ReductionStartVector:
13561385
O << "reduction-start-vector";
13571386
break;

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,8 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan,
861861
VPValue *Op,
862862
ScalarEvolution &SE) {
863863
VPValue *Incoming, *Mask;
864-
if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>(
865-
m_FirstActiveLane(m_VPValue(Mask)), m_VPValue(Incoming))))
864+
if (!match(Op, m_ExtractLane(m_FirstActiveLane(m_VPValue(Mask)),
865+
m_VPValue(Incoming))))
866866
return nullptr;
867867

868868
auto *WideIV = getOptimizableIVOf(Incoming, SE);
@@ -1362,8 +1362,7 @@ static void simplifyRecipe(VPSingleDefRecipe *Def, VPTypeAnalysis &TypeInfo) {
13621362
}
13631363

13641364
// Look through ExtractPenultimateElement (BuildVector ....).
1365-
if (match(Def, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
1366-
m_BuildVector()))) {
1365+
if (match(Def, m_ExtractPenultimateElement(m_BuildVector()))) {
13671366
auto *BuildVector = cast<VPInstruction>(Def->getOperand(0));
13681367
Def->replaceAllUsesWith(
13691368
BuildVector->getOperand(BuildVector->getNumOperands() - 2));
@@ -2175,6 +2174,32 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
21752174
// Set the first operand of RecurSplice to FOR again, after replacing
21762175
// all users.
21772176
RecurSplice->setOperand(0, FOR);
2177+
2178+
// Check for users extracting at the penultimate active lane of the FOR.
2179+
// If only a single lane is active in the current iteration, we need to
2180+
// select the last element from the previous iteration (from the FOR phi
2181+
// directly).
2182+
for (VPUser *U : RecurSplice->users()) {
2183+
if (!match(U, m_ExtractLane(m_LastActiveLane(m_VPValue()),
2184+
m_Specific(RecurSplice))))
2185+
continue;
2186+
2187+
VPBuilder B(cast<VPInstruction>(U));
2188+
VPValue *LastActiveLane = cast<VPInstruction>(U)->getOperand(0);
2189+
Type *I64Ty = Type::getInt64Ty(Plan.getContext());
2190+
VPValue *Zero = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 0));
2191+
VPValue *One = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 1));
2192+
VPValue *PenultimateIndex =
2193+
B.createNaryOp(Instruction::Sub, {LastActiveLane, One});
2194+
VPValue *PenultimateLastIter =
2195+
B.createNaryOp(VPInstruction::ExtractLane,
2196+
{PenultimateIndex, FOR->getBackedgeValue()});
2197+
VPValue *LastPrevIter =
2198+
B.createNaryOp(VPInstruction::ExtractLastElement, FOR);
2199+
VPValue *Cmp = B.createICmp(CmpInst::ICMP_EQ, LastActiveLane, Zero);
2200+
VPValue *Sel = B.createSelect(Cmp, LastPrevIter, PenultimateLastIter);
2201+
cast<VPInstruction>(U)->replaceAllUsesWith(Sel);
2202+
}
21782203
}
21792204
return true;
21802205
}
@@ -3563,6 +3588,34 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
35633588
ToRemove.push_back(Expr);
35643589
}
35653590

3591+
// Expand LastActiveLane into Not + FirstActiveLane + Sub.
3592+
auto *LastActiveL = dyn_cast<VPInstruction>(&R);
3593+
if (LastActiveL &&
3594+
LastActiveL->getOpcode() == VPInstruction::LastActiveLane) {
3595+
// Create Not(Mask) for all operands.
3596+
SmallVector<VPValue *, 2> NotMasks;
3597+
for (VPValue *Op : LastActiveL->operands()) {
3598+
VPValue *NotMask = Builder.createNot(Op, LastActiveL->getDebugLoc());
3599+
NotMasks.push_back(NotMask);
3600+
}
3601+
3602+
// Create FirstActiveLane on the inverted masks.
3603+
VPValue *FirstInactiveLane = Builder.createNaryOp(
3604+
VPInstruction::FirstActiveLane, NotMasks,
3605+
LastActiveL->getDebugLoc(), "first.inactive.lane");
3606+
3607+
// Subtract 1 to get the last active lane.
3608+
VPValue *One = Plan.getOrAddLiveIn(
3609+
ConstantInt::get(Type::getInt64Ty(Plan.getContext()), 1));
3610+
VPValue *LastLane = Builder.createNaryOp(
3611+
Instruction::Sub, {FirstInactiveLane, One},
3612+
LastActiveL->getDebugLoc(), "last.active.lane");
3613+
3614+
LastActiveL->replaceAllUsesWith(LastLane);
3615+
ToRemove.push_back(LastActiveL);
3616+
continue;
3617+
}
3618+
35663619
VPValue *VectorStep;
35673620
VPValue *ScalarStep;
35683621
if (!match(&R, m_VPInstruction<VPInstruction::WideIVStep>(

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
352352
VPValue *Op1;
353353
if (match(&R, m_VPInstruction<VPInstruction::AnyOf>(m_VPValue(Op1))) ||
354354
match(&R, m_FirstActiveLane(m_VPValue(Op1))) ||
355+
match(&R, m_LastActiveLane(m_VPValue(Op1))) ||
355356
match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
356357
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
357358
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
@@ -364,17 +365,21 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
364365
continue;
365366
}
366367
VPValue *Op0;
367-
if (match(&R, m_VPInstruction<VPInstruction::ExtractLane>(
368-
m_VPValue(Op0), m_VPValue(Op1)))) {
368+
if (match(&R, m_ExtractLane(m_VPValue(Op0), m_VPValue(Op1)))) {
369369
addUniformForAllParts(cast<VPInstruction>(&R));
370370
for (unsigned Part = 1; Part != UF; ++Part)
371371
R.addOperand(getValueForPart(Op1, Part));
372372
continue;
373373
}
374374
if (match(&R, m_ExtractLastElement(m_VPValue(Op0))) ||
375-
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
376-
m_VPValue(Op0)))) {
375+
match(&R, m_ExtractPenultimateElement(m_VPValue(Op0)))) {
377376
addUniformForAllParts(cast<VPSingleDefRecipe>(&R));
377+
if (isa<VPFirstOrderRecurrencePHIRecipe>(Op0)) {
378+
assert(match(&R, m_ExtractLastElement(m_VPValue())) &&
379+
"can only extract last element of FOR");
380+
continue;
381+
}
382+
378383
if (Plan.hasScalarVFOnly()) {
379384
auto *I = cast<VPInstruction>(&R);
380385
// Extracting from end with VF = 1 implies retrieving the last or

0 commit comments

Comments
 (0)