Skip to content

Commit 9b92c8d

Browse files
committed
[VPlan] Replace ExtractLast(Elem|LanePerPart) with ExtractLast(Lane/Part)
Replace ExtractLastElement and ExtractLastLanePerPart with more generic and specific ExtractLastLane and ExtractLastPart, which model distinct parts of extracting across parts and lanes. ExtractLastElement == ExtractLastLane(ExtractLastPart) and ExtractLastLanePerPart == ExtractLastLane, the latter clarifying the name of the opcode. A new m_ExtractLastElement matcher is provided for convenience. The patch should be NFC modulo printing changes.
1 parent 792c65c commit 9b92c8d

14 files changed

+161
-116
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,12 +1017,10 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10171017
ComputeAnyOfResult,
10181018
ComputeFindIVResult,
10191019
ComputeReductionResult,
1020-
// Extracts the last lane from its operand if it is a vector, or the last
1021-
// part if scalar. In the latter case, the recipe will be removed during
1022-
// unrolling.
1023-
ExtractLastElement,
1024-
// Extracts the last lane for each part from its operand.
1025-
ExtractLastLanePerPart,
1020+
// Extracts the last part of its operand.
1021+
ExtractLastPart,
1022+
// Extracts the last lane of the current part of its operand.
1023+
ExtractLastLane,
10261024
// Extracts the second-to-last lane from its operand or the second-to-last
10271025
// part if it is scalar. In the latter case, the recipe will be removed
10281026
// during unrolling.

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,17 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
116116
return inferScalarType(R->getOperand(1));
117117
case VPInstruction::FirstActiveLane:
118118
return Type::getIntNTy(Ctx, 64);
119-
case VPInstruction::ExtractLastElement:
120-
case VPInstruction::ExtractLastLanePerPart:
119+
case VPInstruction::ExtractLastLane:
121120
case VPInstruction::ExtractPenultimateElement: {
122121
Type *BaseTy = inferScalarType(R->getOperand(0));
123122
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
124123
return VecTy->getElementType();
125124
return BaseTy;
126125
}
126+
case VPInstruction::ExtractLastPart: {
127+
// ExtractLastPart returns the same type as its operand
128+
return inferScalarType(R->getOperand(0));
129+
}
127130
case VPInstruction::LogicalAnd:
128131
assert(inferScalarType(R->getOperand(0))->isIntegerTy(1) &&
129132
inferScalarType(R->getOperand(1))->isIntegerTy(1) &&

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,9 @@ m_EVL(const Op0_t &Op0) {
383383
}
384384

385385
template <typename Op0_t>
386-
inline VPInstruction_match<VPInstruction::ExtractLastElement, Op0_t>
387-
m_ExtractLastElement(const Op0_t &Op0) {
388-
return m_VPInstruction<VPInstruction::ExtractLastElement>(Op0);
386+
inline VPInstruction_match<VPInstruction::ExtractLastLane, Op0_t>
387+
m_ExtractLastLane(const Op0_t &Op0) {
388+
return m_VPInstruction<VPInstruction::ExtractLastLane>(Op0);
389389
}
390390

391391
template <typename Op0_t, typename Op1_t>
@@ -395,9 +395,17 @@ m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) {
395395
}
396396

397397
template <typename Op0_t>
398-
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
399-
m_ExtractLastLanePerPart(const Op0_t &Op0) {
400-
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
398+
inline VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>
399+
m_ExtractLastPart(const Op0_t &Op0) {
400+
return m_VPInstruction<VPInstruction::ExtractLastPart>(Op0);
401+
}
402+
403+
template <typename Op0_t>
404+
inline VPInstruction_match<
405+
VPInstruction::ExtractLastLane,
406+
VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>>
407+
m_ExtractLastElement(const Op0_t &Op0) {
408+
return m_ExtractLastLane(m_ExtractLastPart(Op0));
401409
}
402410

403411
template <typename Op0_t, typename Op1_t, typename Op2_t>

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
520520
case VPInstruction::CalculateTripCountMinusVF:
521521
case VPInstruction::CanonicalIVIncrementForPart:
522522
case VPInstruction::ExplicitVectorLength:
523-
case VPInstruction::ExtractLastElement:
524-
case VPInstruction::ExtractLastLanePerPart:
523+
case VPInstruction::ExtractLastLane:
524+
case VPInstruction::ExtractLastPart:
525525
case VPInstruction::ExtractPenultimateElement:
526526
case VPInstruction::FirstActiveLane:
527527
case VPInstruction::Not:
@@ -890,8 +890,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
890890

891891
return ReducedPartRdx;
892892
}
893-
case VPInstruction::ExtractLastLanePerPart:
894-
case VPInstruction::ExtractLastElement:
893+
case VPInstruction::ExtractLastLane:
895894
case VPInstruction::ExtractPenultimateElement: {
896895
unsigned Offset =
897896
getOpcode() == VPInstruction::ExtractPenultimateElement ? 2 : 1;
@@ -1159,7 +1158,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11591158
I32Ty, {Arg0Ty, I32Ty, I1Ty});
11601159
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
11611160
}
1162-
case VPInstruction::ExtractLastElement: {
1161+
case VPInstruction::ExtractLastLane: {
11631162
// Add on the cost of extracting the element.
11641163
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
11651164
return Ctx.TTI.getIndexedVectorInstrCostFromEnd(Instruction::ExtractElement,
@@ -1179,8 +1178,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11791178
}
11801179

11811180
bool VPInstruction::isVectorToScalar() const {
1182-
return getOpcode() == VPInstruction::ExtractLastElement ||
1183-
getOpcode() == VPInstruction::ExtractLastLanePerPart ||
1181+
return getOpcode() == VPInstruction::ExtractLastLane ||
11841182
getOpcode() == VPInstruction::ExtractPenultimateElement ||
11851183
getOpcode() == Instruction::ExtractElement ||
11861184
getOpcode() == VPInstruction::ExtractLane ||
@@ -1243,8 +1241,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
12431241
case VPInstruction::CalculateTripCountMinusVF:
12441242
case VPInstruction::CanonicalIVIncrementForPart:
12451243
case VPInstruction::ExtractLane:
1246-
case VPInstruction::ExtractLastElement:
1247-
case VPInstruction::ExtractLastLanePerPart:
1244+
case VPInstruction::ExtractLastLane:
1245+
case VPInstruction::ExtractLastPart:
12481246
case VPInstruction::ExtractPenultimateElement:
12491247
case VPInstruction::ActiveLaneMask:
12501248
case VPInstruction::FirstActiveLane:
@@ -1391,11 +1389,11 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
13911389
case VPInstruction::ExtractLane:
13921390
O << "extract-lane";
13931391
break;
1394-
case VPInstruction::ExtractLastElement:
1395-
O << "extract-last-element";
1392+
case VPInstruction::ExtractLastLane:
1393+
O << "extract-last-lane";
13961394
break;
1397-
case VPInstruction::ExtractLastLanePerPart:
1398-
O << "extract-last-lane-per-part";
1395+
case VPInstruction::ExtractLastPart:
1396+
O << "extract-last-part";
13991397
break;
14001398
case VPInstruction::ExtractPenultimateElement:
14011399
O << "extract-penultimate-element";
@@ -1558,7 +1556,8 @@ void VPIRInstruction::extractLastLaneOfFirstOperand(VPBuilder &Builder) {
15581556
if (Exiting->isLiveIn())
15591557
return;
15601558

1561-
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastElement, {Exiting});
1559+
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastPart, Exiting);
1560+
Exiting = Builder.createNaryOp(VPInstruction::ExtractLastLane, Exiting);
15621561
setOperand(0, Exiting);
15631562
}
15641563

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,9 +1238,8 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
12381238
return;
12391239
}
12401240

1241-
// Look through ExtractLastElement (BuildVector ....).
1242-
if (match(&R, m_CombineOr(m_ExtractLastElement(m_BuildVector()),
1243-
m_ExtractLastLanePerPart(m_BuildVector())))) {
1241+
// Look through ExtractLastLane (BuildVector ....).
1242+
if (match(&R, m_ExtractLastLane(m_BuildVector()))) {
12441243
auto *BuildVector = cast<VPInstruction>(R.getOperand(0));
12451244
Def->replaceAllUsesWith(
12461245
BuildVector->getOperand(BuildVector->getNumOperands() - 1));
@@ -1313,15 +1312,12 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
13131312
return;
13141313
}
13151314

1316-
if (match(Def,
1317-
m_CombineOr(m_ExtractLastElement(m_Broadcast(m_VPValue(A))),
1318-
m_ExtractLastLanePerPart(m_Broadcast(m_VPValue(A)))))) {
1315+
if (match(Def, m_ExtractLastLane(m_Broadcast(m_VPValue(A))))) {
13191316
Def->replaceAllUsesWith(A);
13201317
return;
13211318
}
13221319

1323-
if (match(Def, m_CombineOr(m_ExtractLastElement(m_VPValue(A)),
1324-
m_ExtractLastLanePerPart(m_VPValue(A)))) &&
1320+
if (match(Def, m_ExtractLastLane(m_VPValue(A))) &&
13251321
((isa<VPInstruction>(A) && vputils::isSingleScalar(A)) ||
13261322
(isa<VPReplicateRecipe>(A) &&
13271323
cast<VPReplicateRecipe>(A)->isSingleScalar())) &&
@@ -1330,11 +1326,8 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
13301326
return Def->replaceAllUsesWith(A);
13311327
}
13321328

1333-
if (Plan->getUF() == 1 &&
1334-
match(Def, m_ExtractLastLanePerPart(m_VPValue(A)))) {
1335-
return Def->replaceAllUsesWith(
1336-
Builder.createNaryOp(VPInstruction::ExtractLastElement, {A}));
1337-
}
1329+
if (Plan->getUF() == 1 && match(Def, m_ExtractLastPart(m_VPValue(A))))
1330+
return Def->replaceAllUsesWith(A);
13381331
}
13391332

13401333
void VPlanTransforms::simplifyRecipes(VPlan &Plan) {
@@ -1372,13 +1365,14 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) {
13721365
RepOrWidenR->getUnderlyingInstr(), RepOrWidenR->operands(),
13731366
true /*IsSingleScalar*/, nullptr /*Mask*/, *RepR /*Metadata*/);
13741367
Clone->insertBefore(RepOrWidenR);
1375-
unsigned ExtractOpc =
1376-
vputils::isUniformAcrossVFsAndUFs(RepR->getOperand(1))
1377-
? VPInstruction::ExtractLastElement
1378-
: VPInstruction::ExtractLastLanePerPart;
1379-
auto *Ext = new VPInstruction(ExtractOpc, {Clone->getOperand(0)});
1380-
Ext->insertBefore(Clone);
1381-
Clone->setOperand(0, Ext);
1368+
VPBuilder Builder(Clone);
1369+
VPValue *ExtractOp = Clone->getOperand(0);
1370+
if (vputils::isUniformAcrossVFsAndUFs(RepR->getOperand(1)))
1371+
ExtractOp =
1372+
Builder.createNaryOp(VPInstruction::ExtractLastPart, ExtractOp);
1373+
ExtractOp =
1374+
Builder.createNaryOp(VPInstruction::ExtractLastLane, ExtractOp);
1375+
Clone->setOperand(0, ExtractOp);
13821376
RepR->eraseFromParent();
13831377
continue;
13841378
}
@@ -1389,9 +1383,7 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) {
13891383
if (!vputils::isSingleScalar(RepOrWidenR) ||
13901384
!all_of(RepOrWidenR->users(), [RepOrWidenR](const VPUser *U) {
13911385
return U->usesScalars(RepOrWidenR) ||
1392-
match(cast<VPRecipeBase>(U),
1393-
m_CombineOr(m_ExtractLastElement(m_VPValue()),
1394-
m_ExtractLastLanePerPart(m_VPValue())));
1386+
match(cast<VPRecipeBase>(U), m_ExtractLastPart(m_VPValue()));
13951387
}))
13961388
continue;
13971389

@@ -4412,10 +4404,13 @@ void VPlanTransforms::addScalarResumePhis(
44124404
auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
44134405
assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
44144406
"Cannot handle loops with uncountable early exits");
4415-
if (IsFOR)
4416-
ResumeFromVectorLoop = MiddleBuilder.createNaryOp(
4417-
VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {},
4418-
"vector.recur.extract");
4407+
if (IsFOR) {
4408+
auto *ExtractPart = MiddleBuilder.createNaryOp(
4409+
VPInstruction::ExtractLastPart, ResumeFromVectorLoop);
4410+
ResumeFromVectorLoop =
4411+
MiddleBuilder.createNaryOp(VPInstruction::ExtractLastLane,
4412+
ExtractPart, {}, "vector.recur.extract");
4413+
}
44194414
StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx";
44204415
auto *ResumePhiR = ScalarPHBuilder.createScalarPhi(
44214416
{ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name);
@@ -4513,10 +4508,11 @@ void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan,
45134508
// Now update VPIRInstructions modeling LCSSA phis in the exit block.
45144509
// Extract the penultimate value of the recurrence and use it as operand for
45154510
// the VPIRInstruction modeling the phi.
4516-
for (VPUser *U : FOR->users()) {
4517-
using namespace llvm::VPlanPatternMatch;
4518-
if (!match(U, m_ExtractLastElement(m_Specific(FOR))))
4511+
for (VPRecipeBase &R : make_early_inc_range(
4512+
make_range(MiddleVPBB->getFirstNonPhi(), MiddleVPBB->end()))) {
4513+
if (!match(&R, m_ExtractLastElement(m_Specific(FOR))))
45194514
continue;
4515+
45204516
// For VF vscale x 1, if vscale = 1, we are unable to extract the
45214517
// penultimate value of the recurrence. Instead we rely on the existing
45224518
// extract of the last element from the result of
@@ -4526,9 +4522,11 @@ void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan,
45264522
Range))
45274523
return;
45284524
VPValue *PenultimateElement = MiddleBuilder.createNaryOp(
4529-
VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()},
4525+
VPInstruction::ExtractPenultimateElement,
4526+
MiddleBuilder.createNaryOp(VPInstruction::ExtractLastPart,
4527+
FOR->getBackedgeValue()),
45304528
{}, "vector.recur.extract.for.phi");
4531-
cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement);
4529+
cast<VPInstruction>(&R)->replaceAllUsesWith(PenultimateElement);
45324530
}
45334531
}
45344532
}

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -372,22 +372,27 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
372372
R.addOperand(getValueForPart(Op1, Part));
373373
continue;
374374
}
375-
if (match(&R, m_ExtractLastElement(m_VPValue(Op0))) ||
376-
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
377-
m_VPValue(Op0)))) {
378-
addUniformForAllParts(cast<VPSingleDefRecipe>(&R));
375+
376+
// Handle extraction from the last part. For scalar VF, directly replace
377+
// with the appropriate scalar part. Otherwise, update operand to use the
378+
// part.
379+
if (match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
380+
m_ExtractLastPart(m_VPValue(Op0)))) ||
381+
match(&R, m_ExtractLastElement(m_VPValue(Op0)))) {
382+
auto *I = cast<VPInstruction>(&R);
383+
bool IsPenultimate =
384+
I->getOpcode() == VPInstruction::ExtractPenultimateElement;
385+
unsigned PartIdx = IsPenultimate ? UF - 2 : UF - 1;
386+
379387
if (Plan.hasScalarVFOnly()) {
380-
auto *I = cast<VPInstruction>(&R);
381-
// Extracting from end with VF = 1 implies retrieving the last or
382-
// penultimate scalar part (UF-1 or UF-2).
383-
unsigned Offset =
384-
I->getOpcode() == VPInstruction::ExtractLastElement ? 1 : 2;
385-
I->replaceAllUsesWith(getValueForPart(Op0, UF - Offset));
386-
R.eraseFromParent();
387-
} else {
388-
// Otherwise we extract from the last part.
389-
remapOperands(&R, UF - 1);
388+
// For scalar VF, directly use the scalar part value.
389+
addUniformForAllParts(I);
390+
I->replaceAllUsesWith(getValueForPart(Op0, PartIdx));
391+
continue;
390392
}
393+
// For vector VF, extract from the last part.
394+
addUniformForAllParts(I);
395+
R.setOperand(0, getValueForPart(Op0, UF - 1));
391396
continue;
392397
}
393398

@@ -491,12 +496,10 @@ cloneForLane(VPlan &Plan, VPBuilder &Builder, Type *IdxTy,
491496
continue;
492497
}
493498
if (Lane.getKind() == VPLane::Kind::ScalableLast) {
494-
// Look through mandatory Unpack.
495-
[[maybe_unused]] bool Matched =
496-
match(Op, m_VPInstruction<VPInstruction::Unpack>(m_VPValue(Op)));
497-
assert(Matched && "original op must have been Unpack");
499+
auto *ExtractPart =
500+
Builder.createNaryOp(VPInstruction::ExtractLastPart, {Op});
498501
NewOps.push_back(
499-
Builder.createNaryOp(VPInstruction::ExtractLastElement, {Op}));
502+
Builder.createNaryOp(VPInstruction::ExtractLastLane, {ExtractPart}));
500503
continue;
501504
}
502505
if (vputils::isSingleScalar(Op)) {

llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
3939
; CHECK-NEXT: Successor(s): middle.block
4040
; CHECK-EMPTY:
4141
; CHECK-NEXT: middle.block:
42-
; CHECK-NEXT: EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<[[ACC]]>, vp<[[REDUCE]]>
42+
; CHECK-NEXT: EMIT vp<[[RED_RESULT_PART:%.+]]> = compute-reduction-result ir<[[ACC]]>, vp<[[REDUCE]]>
43+
; CHECK-NEXT: EMIT vp<[[RED_RESULT_PART2:%.+]]> = extract-last-part vp<[[RED_RESULT_PART]]>
44+
; CHECK-NEXT: EMIT vp<[[RED_RESULT:%.+]]> = extract-last-lane vp<[[RED_RESULT_PART2]]>
4345
; CHECK-NEXT: EMIT vp<[[CMP:%.+]]> = icmp eq ir<1024>, vp<[[VEC_TC]]>
4446
; CHECK-NEXT: EMIT branch-on-cond vp<[[CMP]]>
4547
; CHECK-NEXT: Successor(s): ir-bb<exit>, scalar.ph
@@ -50,7 +52,7 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
5052
; CHECK-EMPTY:
5153
; CHECK-NEXT: scalar.ph:
5254
; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<[[VEC_TC]]>, middle.block ], [ ir<0>, ir-bb<entry> ]
53-
; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<[[RED_RESULT]]>, middle.block ], [ ir<0>, ir-bb<entry> ]
55+
; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<[[RED_RESULT_PART]]>, middle.block ], [ ir<0>, ir-bb<entry> ]
5456
; CHECK-NEXT: Successor(s): ir-bb<for.body>
5557
; CHECK-EMPTY:
5658
; CHECK-NEXT: ir-bb<for.body>:

0 commit comments

Comments
 (0)