Skip to content

Commit 3fc7419

Browse files
authored
[VPlan] Replace ExtractLast(Elem|LanePerPart) with ExtractLast(Lane/Part) (#164124)
Replace ExtractLastElement and ExtractLastLanePerPart with the more generic and specific ExtractLastLane and ExtractLastPart, which model the distinct steps of extracting across parts and across lanes. ExtractLastElement == ExtractLastLane(ExtractLastPart) and ExtractLastLanePerPart == ExtractLastLane, the latter clarifying the name of the opcode. A new m_ExtractLastLaneOfLastPart matcher is provided for convenience. The patch should be NFC modulo printing changes. PR: #164124
1 parent 9e7ce77 commit 3fc7419

File tree

14 files changed

+135
-112
lines changed

14 files changed

+135
-112
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8974,14 +8974,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89748974
}
89758975

89768976
// Update all users outside the vector region. Also replace redundant
8977-
// ExtractLastElement.
8977+
// extracts.
89788978
for (auto *U : to_vector(OrigExitingVPV->users())) {
89798979
auto *Parent = cast<VPRecipeBase>(U)->getParent();
89808980
if (FinalReductionResult == U || Parent->getParent())
89818981
continue;
89828982
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
8983-
if (match(U, m_CombineOr(m_ExtractLastElement(m_VPValue()),
8984-
m_ExtractLane(m_VPValue(), m_VPValue()))))
8983+
8984+
// Look through ExtractLastPart.
8985+
if (match(U, m_ExtractLastPart(m_VPValue())))
8986+
U = cast<VPInstruction>(U)->getSingleUser();
8987+
8988+
if (match(U, m_CombineOr(m_ExtractLane(m_VPValue(), m_VPValue()),
8989+
m_ExtractLastLane(m_VPValue()))))
89858990
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
89868991
}
89878992

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,12 +1074,10 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10741074
ComputeAnyOfResult,
10751075
ComputeFindIVResult,
10761076
ComputeReductionResult,
1077-
// Extracts the last lane from its operand if it is a vector, or the last
1078-
// part if scalar. In the latter case, the recipe will be removed during
1079-
// unrolling.
1080-
ExtractLastElement,
1081-
// Extracts the last lane for each part from its operand.
1082-
ExtractLastLanePerPart,
1077+
// Extracts the last part of its operand. Removed during unrolling.
1078+
ExtractLastPart,
1079+
// Extracts the last lane of its vector operand, per part.
1080+
ExtractLastLane,
10831081
// Extracts the second-to-last lane from its operand or the second-to-last
10841082
// part if it is scalar. In the latter case, the recipe will be removed
10851083
// during unrolling.
@@ -1466,10 +1464,10 @@ class VPIRInstruction : public VPRecipeBase {
14661464
return true;
14671465
}
14681466

1469-
/// Update the recipes first operand to the last lane of the operand using \p
1470-
/// Builder. Must only be used for VPIRInstructions with at least one operand
1471-
/// wrapping a PHINode.
1472-
void extractLastLaneOfFirstOperand(VPBuilder &Builder);
1467+
/// Update the recipe's first operand to the last lane of the last part of the
1468+
/// operand using \p Builder. Must only be used for VPIRInstructions with at
1469+
/// least one operand wrapping a PHINode.
1470+
void extractLastLaneOfLastPartOfFirstOperand(VPBuilder &Builder);
14731471

14741472
protected:
14751473
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,15 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
119119
case VPInstruction::FirstActiveLane:
120120
case VPInstruction::LastActiveLane:
121121
return Type::getIntNTy(Ctx, 64);
122-
case VPInstruction::ExtractLastElement:
123-
case VPInstruction::ExtractLastLanePerPart:
122+
case VPInstruction::ExtractLastLane:
124123
case VPInstruction::ExtractPenultimateElement: {
125124
Type *BaseTy = inferScalarType(R->getOperand(0));
126125
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
127126
return VecTy->getElementType();
128127
return BaseTy;
129128
}
129+
case VPInstruction::ExtractLastPart:
130+
return inferScalarType(R->getOperand(0));
130131
case VPInstruction::LogicalAnd:
131132
assert(inferScalarType(R->getOperand(0))->isIntegerTy(1) &&
132133
inferScalarType(R->getOperand(1))->isIntegerTy(1) &&
@@ -540,11 +541,14 @@ SmallVector<VPRegisterUsage, 8> llvm::calculateRegisterUsageForPlan(
540541
SmallMapVector<unsigned, unsigned, 4> RegUsage;
541542

542543
for (auto *VPV : OpenIntervals) {
543-
// Skip values that weren't present in the original loop.
544-
// TODO: Remove after removing the legacy
544+
// Skip artificial values or values that weren't present in the original
545+
// loop.
546+
// TODO: Remove skipping values that weren't present in the original
547+
// loop after removing the legacy
545548
// LoopVectorizationCostModel::calculateRegisterUsage
546549
if (isa<VPVectorPointerRecipe, VPVectorEndPointerRecipe,
547-
VPBranchOnMaskRecipe>(VPV))
550+
VPBranchOnMaskRecipe>(VPV) ||
551+
match(VPV, m_ExtractLastPart(m_VPValue())))
548552
continue;
549553

550554
if (VFs[J].isScalar() ||

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ static void createExtractsForLiveOuts(VPlan &Plan, VPBasicBlock *MiddleVPBB) {
499499
ExitIRI->getParent()->getSinglePredecessor() == MiddleVPBB &&
500500
"exit values from early exits must be fixed when branch to "
501501
"early-exit is added");
502-
ExitIRI->extractLastLaneOfFirstOperand(B);
502+
ExitIRI->extractLastLaneOfLastPartOfFirstOperand(B);
503503
}
504504
}
505505
}

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,9 @@ m_EVL(const Op0_t &Op0) {
387387
}
388388

389389
template <typename Op0_t>
390-
inline VPInstruction_match<VPInstruction::ExtractLastElement, Op0_t>
391-
m_ExtractLastElement(const Op0_t &Op0) {
392-
return m_VPInstruction<VPInstruction::ExtractLastElement>(Op0);
390+
inline VPInstruction_match<VPInstruction::ExtractLastLane, Op0_t>
391+
m_ExtractLastLane(const Op0_t &Op0) {
392+
return m_VPInstruction<VPInstruction::ExtractLastLane>(Op0);
393393
}
394394

395395
template <typename Op0_t, typename Op1_t>
@@ -405,9 +405,17 @@ m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
405405
}
406406

407407
template <typename Op0_t>
408-
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
409-
m_ExtractLastLanePerPart(const Op0_t &Op0) {
410-
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
408+
inline VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>
409+
m_ExtractLastPart(const Op0_t &Op0) {
410+
return m_VPInstruction<VPInstruction::ExtractLastPart>(Op0);
411+
}
412+
413+
template <typename Op0_t>
414+
inline VPInstruction_match<
415+
VPInstruction::ExtractLastLane,
416+
VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>>
417+
m_ExtractLastLaneOfLastPart(const Op0_t &Op0) {
418+
return m_ExtractLastLane(m_ExtractLastPart(Op0));
411419
}
412420

413421
template <typename Op0_t>

llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,9 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
314314
"the exit block must have middle block as single predecessor");
315315

316316
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
317-
for (auto &P : EB->phis()) {
318-
auto *ExitIRI = cast<VPIRPhi>(&P);
319-
VPValue *Inc = ExitIRI->getIncomingValue(0);
317+
for (VPRecipeBase &R : *Plan.getMiddleBlock()) {
320318
VPValue *Op;
321-
if (!match(Inc, m_ExtractLastElement(m_VPValue(Op))))
319+
if (!match(&R, m_ExtractLastLane(m_ExtractLastPart(m_VPValue(Op)))))
322320
continue;
323321

324322
// Compute the index of the last active lane.
@@ -327,7 +325,7 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
327325
B.createNaryOp(VPInstruction::LastActiveLane, HeaderMask);
328326
auto *Ext =
329327
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
330-
Inc->replaceAllUsesWith(Ext);
328+
R.getVPSingleValue()->replaceAllUsesWith(Ext);
331329
}
332330
}
333331
return Predicator.getBlockMaskCache();

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,8 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
439439
case VPInstruction::CalculateTripCountMinusVF:
440440
case VPInstruction::CanonicalIVIncrementForPart:
441441
case VPInstruction::ExplicitVectorLength:
442-
case VPInstruction::ExtractLastElement:
443-
case VPInstruction::ExtractLastLanePerPart:
442+
case VPInstruction::ExtractLastLane:
443+
case VPInstruction::ExtractLastPart:
444444
case VPInstruction::ExtractPenultimateElement:
445445
case VPInstruction::Not:
446446
case VPInstruction::ResumeForEpilogue:
@@ -813,8 +813,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
813813

814814
return ReducedPartRdx;
815815
}
816-
case VPInstruction::ExtractLastLanePerPart:
817-
case VPInstruction::ExtractLastElement:
816+
case VPInstruction::ExtractLastLane:
818817
case VPInstruction::ExtractPenultimateElement: {
819818
unsigned Offset =
820819
getOpcode() == VPInstruction::ExtractPenultimateElement ? 2 : 1;
@@ -825,6 +824,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
825824
// Extract lane VF - Offset from the operand.
826825
Res = State.get(getOperand(0), VPLane::getLaneFromEnd(State.VF, Offset));
827826
} else {
827+
// TODO: Remove ExtractLastLane for scalar VFs.
828828
assert(Offset <= 1 && "invalid offset to extract from");
829829
Res = State.get(getOperand(0));
830830
}
@@ -1103,7 +1103,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11031103
I32Ty, {Arg0Ty, I32Ty, I1Ty});
11041104
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
11051105
}
1106-
case VPInstruction::ExtractLastElement: {
1106+
case VPInstruction::ExtractLastLane: {
11071107
// Add on the cost of extracting the element.
11081108
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
11091109
return Ctx.TTI.getIndexedVectorInstrCostFromEnd(Instruction::ExtractElement,
@@ -1123,8 +1123,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11231123
}
11241124

11251125
bool VPInstruction::isVectorToScalar() const {
1126-
return getOpcode() == VPInstruction::ExtractLastElement ||
1127-
getOpcode() == VPInstruction::ExtractLastLanePerPart ||
1126+
return getOpcode() == VPInstruction::ExtractLastLane ||
11281127
getOpcode() == VPInstruction::ExtractPenultimateElement ||
11291128
getOpcode() == Instruction::ExtractElement ||
11301129
getOpcode() == VPInstruction::ExtractLane ||
@@ -1191,8 +1190,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
11911190
case VPInstruction::CalculateTripCountMinusVF:
11921191
case VPInstruction::CanonicalIVIncrementForPart:
11931192
case VPInstruction::ExtractLane:
1194-
case VPInstruction::ExtractLastElement:
1195-
case VPInstruction::ExtractLastLanePerPart:
1193+
case VPInstruction::ExtractLastLane:
1194+
case VPInstruction::ExtractLastPart:
11961195
case VPInstruction::ExtractPenultimateElement:
11971196
case VPInstruction::ActiveLaneMask:
11981197
case VPInstruction::ExplicitVectorLength:
@@ -1341,11 +1340,11 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
13411340
case VPInstruction::ExtractLane:
13421341
O << "extract-lane";
13431342
break;
1344-
case VPInstruction::ExtractLastElement:
1345-
O << "extract-last-element";
1343+
case VPInstruction::ExtractLastLane:
1344+
O << "extract-last-lane";
13461345
break;
1347-
case VPInstruction::ExtractLastLanePerPart:
1348-
O << "extract-last-lane-per-part";
1346+
case VPInstruction::ExtractLastPart:
1347+
O << "extract-last-part";
13491348
break;
13501349
case VPInstruction::ExtractPenultimateElement:
13511350
O << "extract-penultimate-element";
@@ -1498,15 +1497,17 @@ InstructionCost VPIRInstruction::computeCost(ElementCount VF,
14981497
return 0;
14991498
}
15001499

1501-
void VPIRInstruction::extractLastLaneOfFirstOperand(VPBuilder &Builder) {
1500+
void VPIRInstruction::extractLastLaneOfLastPartOfFirstOperand(
1501+
VPBuilder &Builder) {
15021502
assert(isa<PHINode>(getInstruction()) &&
15031503
"can only update exiting operands to phi nodes");
15041504
assert(getNumOperands() > 0 && "must have at least one operand");
15051505
VPValue *Exiting = getOperand(0);
15061506
if (Exiting->isLiveIn())
15071507
return;
15081508

1509-
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastElement, {Exiting});
1509+
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastPart, Exiting);
1510+
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastLane, Exiting);
15101511
setOperand(0, Exiting);
15111512
}
15121513

0 commit comments

Comments
 (0)