Skip to content

Commit dbb1410

Browse files
committed
[VPlan] Replace ExtractLast(Elem|LanePerPart) with ExtractLast(Lane/Part)
1 parent 65666b2 commit dbb1410

13 files changed

+141
-112
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8963,14 +8963,23 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89638963
}
89648964

89658965
// Update all users outside the vector region. Also replace redundant
8966-
// ExtractLastElement.
8966+
// extracts.
89678967
for (auto *U : to_vector(OrigExitingVPV->users())) {
89688968
auto *Parent = cast<VPRecipeBase>(U)->getParent();
89698969
if (FinalReductionResult == U || Parent->getParent())
89708970
continue;
89718971
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
8972-
if (match(U, m_CombineOr(m_ExtractLastElement(m_VPValue()),
8973-
m_ExtractLane(m_VPValue(), m_VPValue()))))
8972+
8973+
// Look through ExtractLastPart.
8974+
if (match(U, m_ExtractLastPart(m_VPValue()))) {
8975+
auto *ExtractPart = cast<VPInstruction>(U);
8976+
if (ExtractPart->getNumUsers() != 1)
8977+
continue;
8978+
U = *ExtractPart->user_begin();
8979+
}
8980+
8981+
if (match(U, m_CombineOr(m_ExtractLane(m_VPValue(), m_VPValue()),
8982+
m_ExtractLastLane(m_VPValue()))))
89748983
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
89758984
}
89768985

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,12 +1069,10 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10691069
ComputeAnyOfResult,
10701070
ComputeFindIVResult,
10711071
ComputeReductionResult,
1072-
// Extracts the last lane from its operand if it is a vector, or the last
1073-
// part if scalar. In the latter case, the recipe will be removed during
1074-
// unrolling.
1075-
ExtractLastElement,
1076-
// Extracts the last lane for each part from its operand.
1077-
ExtractLastLanePerPart,
1072+
// Extracts the last part of its operand. Removed during unrolling.
1073+
ExtractLastPart,
1074+
// Extracts the last lane of its vector operand, per part.
1075+
ExtractLastLane,
10781076
// Extracts the second-to-last lane from its operand or the second-to-last
10791077
// part if it is scalar. In the latter case, the recipe will be removed
10801078
// during unrolling.
@@ -1461,10 +1459,10 @@ class VPIRInstruction : public VPRecipeBase {
14611459
return true;
14621460
}
14631461

1464-
/// Update the recipes first operand to the last lane of the operand using \p
1465-
/// Builder. Must only be used for VPIRInstructions with at least one operand
1466-
/// wrapping a PHINode.
1467-
void extractLastLaneOfFirstOperand(VPBuilder &Builder);
1462+
/// Update the recipe's first operand to the final lane of the operand using
1463+
/// \p Builder. Must only be used for VPIRInstructions with at least one
1464+
/// operand wrapping a PHINode.
1465+
void extractFinalLaneOfFirstOperand(VPBuilder &Builder);
14681466

14691467
protected:
14701468
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,18 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
119119
case VPInstruction::FirstActiveLane:
120120
case VPInstruction::LastActiveLane:
121121
return Type::getIntNTy(Ctx, 64);
122-
case VPInstruction::ExtractLastElement:
123-
case VPInstruction::ExtractLastLanePerPart:
122+
case VPInstruction::ExtractLastLane:
124123
case VPInstruction::ExtractPenultimateElement: {
125124
Type *BaseTy = inferScalarType(R->getOperand(0));
126125
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
127126
return VecTy->getElementType();
128127
return BaseTy;
129128
}
129+
case VPInstruction::ExtractLastPart: {
130+
// Element type of ExtractLastPart is equal to the element type of its
131+
// operand.
132+
return inferScalarType(R->getOperand(0));
133+
}
130134
case VPInstruction::LogicalAnd:
131135
assert(inferScalarType(R->getOperand(0))->isIntegerTy(1) &&
132136
inferScalarType(R->getOperand(1))->isIntegerTy(1) &&
@@ -540,11 +544,14 @@ SmallVector<VPRegisterUsage, 8> llvm::calculateRegisterUsageForPlan(
540544
SmallMapVector<unsigned, unsigned, 4> RegUsage;
541545

542546
for (auto *VPV : OpenIntervals) {
543-
// Skip values that weren't present in the original loop.
544-
// TODO: Remove after removing the legacy
547+
// Skip artificial values or values that weren't present in the original
548+
// loop.
549+
// TODO: Remove skipping values that weren't present in the original
550+
// loop after removing the legacy
545551
// LoopVectorizationCostModel::calculateRegisterUsage
546552
if (isa<VPVectorPointerRecipe, VPVectorEndPointerRecipe,
547-
VPBranchOnMaskRecipe>(VPV))
553+
VPBranchOnMaskRecipe>(VPV) ||
554+
match(VPV, m_ExtractLastPart(m_VPValue())))
548555
continue;
549556

550557
if (VFs[J].isScalar() ||

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ static void createExtractsForLiveOuts(VPlan &Plan, VPBasicBlock *MiddleVPBB) {
499499
ExitIRI->getParent()->getSinglePredecessor() == MiddleVPBB &&
500500
"exit values from early exits must be fixed when branch to "
501501
"early-exit is added");
502-
ExitIRI->extractLastLaneOfFirstOperand(B);
502+
ExitIRI->extractFinalLaneOfFirstOperand(B);
503503
}
504504
}
505505
}

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,9 @@ m_EVL(const Op0_t &Op0) {
387387
}
388388

389389
template <typename Op0_t>
390-
inline VPInstruction_match<VPInstruction::ExtractLastElement, Op0_t>
391-
m_ExtractLastElement(const Op0_t &Op0) {
392-
return m_VPInstruction<VPInstruction::ExtractLastElement>(Op0);
390+
inline VPInstruction_match<VPInstruction::ExtractLastLane, Op0_t>
391+
m_ExtractLastLane(const Op0_t &Op0) {
392+
return m_VPInstruction<VPInstruction::ExtractLastLane>(Op0);
393393
}
394394

395395
template <typename Op0_t, typename Op1_t>
@@ -405,9 +405,17 @@ m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
405405
}
406406

407407
template <typename Op0_t>
408-
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
409-
m_ExtractLastLanePerPart(const Op0_t &Op0) {
410-
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
408+
inline VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>
409+
m_ExtractLastPart(const Op0_t &Op0) {
410+
return m_VPInstruction<VPInstruction::ExtractLastPart>(Op0);
411+
}
412+
413+
template <typename Op0_t>
414+
inline VPInstruction_match<
415+
VPInstruction::ExtractLastLane,
416+
VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>>
417+
m_ExtractFinalLane(const Op0_t &Op0) {
418+
return m_ExtractLastLane(m_ExtractLastPart(Op0));
411419
}
412420

413421
template <typename Op0_t>

llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,9 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
314314
"the exit block must have middle block as single predecessor");
315315

316316
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
317-
for (auto &P : EB->phis()) {
318-
auto *ExitIRI = cast<VPIRPhi>(&P);
319-
VPValue *Inc = ExitIRI->getIncomingValue(0);
317+
for (VPRecipeBase &R : *Plan.getMiddleBlock()) {
320318
VPValue *Op;
321-
if (!match(Inc, m_ExtractLastElement(m_VPValue(Op))))
319+
if (!match(&R, m_ExtractLastLane(m_ExtractLastPart(m_VPValue(Op)))))
322320
continue;
323321

324322
// Compute the index of the last active lane.
@@ -327,7 +325,7 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
327325
B.createNaryOp(VPInstruction::LastActiveLane, HeaderMask);
328326
auto *Ext =
329327
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
330-
Inc->replaceAllUsesWith(Ext);
328+
R.getVPSingleValue()->replaceAllUsesWith(Ext);
331329
}
332330
}
333331
return Predicator.getBlockMaskCache();

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,8 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
437437
case VPInstruction::CalculateTripCountMinusVF:
438438
case VPInstruction::CanonicalIVIncrementForPart:
439439
case VPInstruction::ExplicitVectorLength:
440-
case VPInstruction::ExtractLastElement:
441-
case VPInstruction::ExtractLastLanePerPart:
440+
case VPInstruction::ExtractLastLane:
441+
case VPInstruction::ExtractLastPart:
442442
case VPInstruction::ExtractPenultimateElement:
443443
case VPInstruction::Not:
444444
case VPInstruction::ResumeForEpilogue:
@@ -815,8 +815,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
815815

816816
return ReducedPartRdx;
817817
}
818-
case VPInstruction::ExtractLastLanePerPart:
819-
case VPInstruction::ExtractLastElement:
818+
case VPInstruction::ExtractLastLane:
820819
case VPInstruction::ExtractPenultimateElement: {
821820
unsigned Offset =
822821
getOpcode() == VPInstruction::ExtractPenultimateElement ? 2 : 1;
@@ -827,6 +826,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
827826
// Extract lane VF - Offset from the operand.
828827
Res = State.get(getOperand(0), VPLane::getLaneFromEnd(State.VF, Offset));
829828
} else {
829+
// TODO: Remove ExtractLastLane for scalar VFs.
830830
assert(Offset <= 1 && "invalid offset to extract from");
831831
Res = State.get(getOperand(0));
832832
}
@@ -1107,7 +1107,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11071107
I32Ty, {Arg0Ty, I32Ty, I1Ty});
11081108
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
11091109
}
1110-
case VPInstruction::ExtractLastElement: {
1110+
case VPInstruction::ExtractLastLane: {
11111111
// Add on the cost of extracting the element.
11121112
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
11131113
return Ctx.TTI.getIndexedVectorInstrCostFromEnd(Instruction::ExtractElement,
@@ -1127,8 +1127,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11271127
}
11281128

11291129
bool VPInstruction::isVectorToScalar() const {
1130-
return getOpcode() == VPInstruction::ExtractLastElement ||
1131-
getOpcode() == VPInstruction::ExtractLastLanePerPart ||
1130+
return getOpcode() == VPInstruction::ExtractLastLane ||
11321131
getOpcode() == VPInstruction::ExtractPenultimateElement ||
11331132
getOpcode() == Instruction::ExtractElement ||
11341133
getOpcode() == VPInstruction::ExtractLane ||
@@ -1195,8 +1194,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
11951194
case VPInstruction::CalculateTripCountMinusVF:
11961195
case VPInstruction::CanonicalIVIncrementForPart:
11971196
case VPInstruction::ExtractLane:
1198-
case VPInstruction::ExtractLastElement:
1199-
case VPInstruction::ExtractLastLanePerPart:
1197+
case VPInstruction::ExtractLastLane:
1198+
case VPInstruction::ExtractLastPart:
12001199
case VPInstruction::ExtractPenultimateElement:
12011200
case VPInstruction::ActiveLaneMask:
12021201
case VPInstruction::ExplicitVectorLength:
@@ -1345,11 +1344,11 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
13451344
case VPInstruction::ExtractLane:
13461345
O << "extract-lane";
13471346
break;
1348-
case VPInstruction::ExtractLastElement:
1349-
O << "extract-last-element";
1347+
case VPInstruction::ExtractLastLane:
1348+
O << "extract-last-lane";
13501349
break;
1351-
case VPInstruction::ExtractLastLanePerPart:
1352-
O << "extract-last-lane-per-part";
1350+
case VPInstruction::ExtractLastPart:
1351+
O << "extract-last-part";
13531352
break;
13541353
case VPInstruction::ExtractPenultimateElement:
13551354
O << "extract-penultimate-element";
@@ -1502,15 +1501,16 @@ InstructionCost VPIRInstruction::computeCost(ElementCount VF,
15021501
return 0;
15031502
}
15041503

1505-
void VPIRInstruction::extractLastLaneOfFirstOperand(VPBuilder &Builder) {
1504+
void VPIRInstruction::extractFinalLaneOfFirstOperand(VPBuilder &Builder) {
15061505
assert(isa<PHINode>(getInstruction()) &&
15071506
"can only update exiting operands to phi nodes");
15081507
assert(getNumOperands() > 0 && "must have at least one operand");
15091508
VPValue *Exiting = getOperand(0);
15101509
if (Exiting->isLiveIn())
15111510
return;
15121511

1513-
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastElement, {Exiting});
1512+
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastPart, Exiting);
1513+
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastLane, Exiting);
15141514
setOperand(0, Exiting);
15151515
}
15161516

0 commit comments

Comments
 (0)