Skip to content

Commit 7f54fcc

Browse files
authored
[VPlan] Add ExtractLastLanePerPart, use in narrowToSingleScalar. (#163056)
When narrowing stores of a single-scalar, we currently use ExtractLastElement, which extracts the last element across all parts. This is not correct if the store's address is not uniform across all parts. If it is only uniform-per-part, the last lane per part must be extracted. Add a new ExtractLastLanePerPart opcode to handle this correctly. Most transforms apply to both ExtractLastElement and ExtractLastLanePerPart, with the only difference being their treatment during unrolling. Fixes llvm/llvm-project#162498. PR: llvm/llvm-project#163056
1 parent 55bd6fb commit 7f54fcc

File tree

7 files changed

+48
-14
lines changed

7 files changed

+48
-14
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,8 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10121012
// part if scalar. In the latter case, the recipe will be removed during
10131013
// unrolling.
10141014
ExtractLastElement,
1015+
// Extracts the last lane for each part from its operand.
1016+
ExtractLastLanePerPart,
10151017
// Extracts the second-to-last lane from its operand or the second-to-last
10161018
// part if it is scalar. In the latter case, the recipe will be removed
10171019
// during unrolling.

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
116116
case VPInstruction::FirstActiveLane:
117117
return Type::getIntNTy(Ctx, 64);
118118
case VPInstruction::ExtractLastElement:
119+
case VPInstruction::ExtractLastLanePerPart:
119120
case VPInstruction::ExtractPenultimateElement: {
120121
Type *BaseTy = inferScalarType(R->getOperand(0));
121122
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,12 @@ m_ExtractLastElement(const Op0_t &Op0) {
372372
return m_VPInstruction<VPInstruction::ExtractLastElement>(Op0);
373373
}
374374

375+
template <typename Op0_t>
376+
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
377+
m_ExtractLastLanePerPart(const Op0_t &Op0) {
378+
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
379+
}
380+
375381
template <typename Op0_t, typename Op1_t, typename Op2_t>
376382
inline VPInstruction_match<VPInstruction::ActiveLaneMask, Op0_t, Op1_t, Op2_t>
377383
m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
511511
case VPInstruction::CanonicalIVIncrementForPart:
512512
case VPInstruction::ExplicitVectorLength:
513513
case VPInstruction::ExtractLastElement:
514+
case VPInstruction::ExtractLastLanePerPart:
514515
case VPInstruction::ExtractPenultimateElement:
515516
case VPInstruction::FirstActiveLane:
516517
case VPInstruction::Not:
@@ -878,9 +879,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
878879

879880
return ReducedPartRdx;
880881
}
882+
case VPInstruction::ExtractLastLanePerPart:
881883
case VPInstruction::ExtractLastElement:
882884
case VPInstruction::ExtractPenultimateElement: {
883-
unsigned Offset = getOpcode() == VPInstruction::ExtractLastElement ? 1 : 2;
885+
unsigned Offset =
886+
getOpcode() == VPInstruction::ExtractPenultimateElement ? 2 : 1;
884887
Value *Res;
885888
if (State.VF.isVector()) {
886889
assert(Offset <= State.VF.getKnownMinValue() &&
@@ -1166,6 +1169,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11661169

11671170
bool VPInstruction::isVectorToScalar() const {
11681171
return getOpcode() == VPInstruction::ExtractLastElement ||
1172+
getOpcode() == VPInstruction::ExtractLastLanePerPart ||
11691173
getOpcode() == VPInstruction::ExtractPenultimateElement ||
11701174
getOpcode() == Instruction::ExtractElement ||
11711175
getOpcode() == VPInstruction::ExtractLane ||
@@ -1229,6 +1233,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
12291233
case VPInstruction::CanonicalIVIncrementForPart:
12301234
case VPInstruction::ExtractLane:
12311235
case VPInstruction::ExtractLastElement:
1236+
case VPInstruction::ExtractLastLanePerPart:
12321237
case VPInstruction::ExtractPenultimateElement:
12331238
case VPInstruction::ActiveLaneMask:
12341239
case VPInstruction::FirstActiveLane:
@@ -1376,6 +1381,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
13761381
case VPInstruction::ExtractLastElement:
13771382
O << "extract-last-element";
13781383
break;
1384+
case VPInstruction::ExtractLastLanePerPart:
1385+
O << "extract-last-lane-per-part";
1386+
break;
13791387
case VPInstruction::ExtractPenultimateElement:
13801388
O << "extract-penultimate-element";
13811389
break;

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,8 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
12091209
}
12101210

12111211
// Look through ExtractLastElement (BuildVector ....).
1212-
if (match(&R, m_ExtractLastElement(m_BuildVector()))) {
1212+
if (match(&R, m_CombineOr(m_ExtractLastElement(m_BuildVector()),
1213+
m_ExtractLastLanePerPart(m_BuildVector())))) {
12131214
auto *BuildVector = cast<VPInstruction>(R.getOperand(0));
12141215
Def->replaceAllUsesWith(
12151216
BuildVector->getOperand(BuildVector->getNumOperands() - 1));
@@ -1275,20 +1276,28 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
12751276
return;
12761277
}
12771278

1278-
if (match(Def, m_ExtractLastElement(m_Broadcast(m_VPValue(A))))) {
1279+
if (match(Def,
1280+
m_CombineOr(m_ExtractLastElement(m_Broadcast(m_VPValue(A))),
1281+
m_ExtractLastLanePerPart(m_Broadcast(m_VPValue(A)))))) {
12791282
Def->replaceAllUsesWith(A);
12801283
return;
12811284
}
12821285

1283-
if (match(Def,
1284-
m_VPInstruction<VPInstruction::ExtractLastElement>(m_VPValue(A))) &&
1286+
if (match(Def, m_CombineOr(m_ExtractLastElement(m_VPValue(A)),
1287+
m_ExtractLastLanePerPart(m_VPValue(A)))) &&
12851288
((isa<VPInstruction>(A) && vputils::isSingleScalar(A)) ||
12861289
(isa<VPReplicateRecipe>(A) &&
12871290
cast<VPReplicateRecipe>(A)->isSingleScalar())) &&
12881291
all_of(A->users(),
12891292
[Def, A](VPUser *U) { return U->usesScalars(A) || Def == U; })) {
12901293
return Def->replaceAllUsesWith(A);
12911294
}
1295+
1296+
if (Plan->getUF() == 1 &&
1297+
match(Def, m_ExtractLastLanePerPart(m_VPValue(A)))) {
1298+
return Def->replaceAllUsesWith(
1299+
Builder.createNaryOp(VPInstruction::ExtractLastElement, {A}));
1300+
}
12921301
}
12931302

12941303
void VPlanTransforms::simplifyRecipes(VPlan &Plan) {
@@ -1326,8 +1335,11 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) {
13261335
RepOrWidenR->getUnderlyingInstr(), RepOrWidenR->operands(),
13271336
true /*IsSingleScalar*/, nullptr /*Mask*/, *RepR /*Metadata*/);
13281337
Clone->insertBefore(RepOrWidenR);
1329-
auto *Ext = new VPInstruction(VPInstruction::ExtractLastElement,
1330-
{Clone->getOperand(0)});
1338+
unsigned ExtractOpc =
1339+
vputils::isUniformAcrossVFsAndUFs(RepR->getOperand(1))
1340+
? VPInstruction::ExtractLastElement
1341+
: VPInstruction::ExtractLastLanePerPart;
1342+
auto *Ext = new VPInstruction(ExtractOpc, {Clone->getOperand(0)});
13311343
Ext->insertBefore(Clone);
13321344
Clone->setOperand(0, Ext);
13331345
RepR->eraseFromParent();
@@ -1341,7 +1353,8 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) {
13411353
!all_of(RepOrWidenR->users(), [RepOrWidenR](const VPUser *U) {
13421354
return U->usesScalars(RepOrWidenR) ||
13431355
match(cast<VPRecipeBase>(U),
1344-
m_ExtractLastElement(m_VPValue()));
1356+
m_CombineOr(m_ExtractLastElement(m_VPValue()),
1357+
m_ExtractLastLanePerPart(m_VPValue())));
13451358
}))
13461359
continue;
13471360

llvm/test/Transforms/LoopVectorize/AArch64/replicating-load-store-costs.ll

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,17 +153,20 @@ define void @uniform_gep_for_replicating_gep(ptr %dst) {
153153
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <2 x i32> [ <i32 0, i32 1>, %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ]
154154
; CHECK-NEXT: [[STEP_ADD:%.*]] = add <2 x i32> [[VEC_IND]], splat (i32 2)
155155
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[INDEX]], 2
156-
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq <2 x i32> [[STEP_ADD]], zeroinitializer
156+
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq <2 x i32> [[VEC_IND]], zeroinitializer
157+
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq <2 x i32> [[STEP_ADD]], zeroinitializer
157158
; CHECK-NEXT: [[TMP8:%.*]] = lshr i32 [[INDEX]], 1
158159
; CHECK-NEXT: [[TMP9:%.*]] = lshr i32 [[TMP2]], 1
159160
; CHECK-NEXT: [[TMP11:%.*]] = zext <2 x i1> [[TMP5]] to <2 x i8>
161+
; CHECK-NEXT: [[TMP6:%.*]] = zext <2 x i1> [[TMP3]] to <2 x i8>
160162
; CHECK-NEXT: [[TMP14:%.*]] = zext i32 [[TMP8]] to i64
161163
; CHECK-NEXT: [[TMP15:%.*]] = zext i32 [[TMP9]] to i64
162164
; CHECK-NEXT: [[TMP18:%.*]] = getelementptr i64, ptr [[DST]], i64 [[TMP14]]
163165
; CHECK-NEXT: [[TMP19:%.*]] = getelementptr i64, ptr [[DST]], i64 [[TMP15]]
164166
; CHECK-NEXT: [[TMP22:%.*]] = extractelement <2 x i8> [[TMP11]], i32 1
167+
; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x i8> [[TMP6]], i32 1
165168
; CHECK-NEXT: store i8 [[TMP22]], ptr [[TMP18]], align 1
166-
; CHECK-NEXT: store i8 [[TMP22]], ptr [[TMP19]], align 1
169+
; CHECK-NEXT: store i8 [[TMP12]], ptr [[TMP19]], align 1
167170
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
168171
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <2 x i32> [[STEP_ADD]], splat (i32 2)
169172
; CHECK-NEXT: [[TMP24:%.*]] = icmp eq i32 [[INDEX_NEXT]], 128

llvm/test/Transforms/LoopVectorize/narrow-to-single-scalar.ll

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ exit:
7474
ret void
7575
}
7676

77-
; FIXME: Currently this mis-compiled when interleaving; all stores store the
78-
; last lane of the last part, instead of the last lane per part.
77+
; Check each unrolled store stores the last lane of the corresponding part.
7978
; Test case for https://github.com/llvm/llvm-project/issues/162498.
8079
define void @narrow_to_single_scalar_store_address_not_uniform_across_all_parts(ptr %dst) {
8180
; VF4IC1-LABEL: define void @narrow_to_single_scalar_store_address_not_uniform_across_all_parts(
@@ -121,13 +120,15 @@ define void @narrow_to_single_scalar_store_address_not_uniform_across_all_parts(
121120
; VF2IC2-NEXT: br label %[[VECTOR_BODY:.*]]
122121
; VF2IC2: [[VECTOR_BODY]]:
123122
; VF2IC2-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
123+
; VF2IC2-NEXT: [[TMP7:%.*]] = add i32 [[INDEX]], 0
124+
; VF2IC2-NEXT: [[TMP8:%.*]] = add i32 [[INDEX]], 1
124125
; VF2IC2-NEXT: [[TMP0:%.*]] = add i32 [[INDEX]], 2
125126
; VF2IC2-NEXT: [[TMP1:%.*]] = add i32 [[INDEX]], 3
126-
; VF2IC2-NEXT: [[TMP2:%.*]] = lshr i32 [[INDEX]], 1
127+
; VF2IC2-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP7]], 1
127128
; VF2IC2-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 1
128129
; VF2IC2-NEXT: [[TMP4:%.*]] = getelementptr i32, ptr [[DST]], i32 [[TMP2]]
129130
; VF2IC2-NEXT: [[TMP5:%.*]] = getelementptr i32, ptr [[DST]], i32 [[TMP3]]
130-
; VF2IC2-NEXT: store i32 [[TMP1]], ptr [[TMP4]], align 4
131+
; VF2IC2-NEXT: store i32 [[TMP8]], ptr [[TMP4]], align 4
131132
; VF2IC2-NEXT: store i32 [[TMP1]], ptr [[TMP5]], align 4
132133
; VF2IC2-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
133134
; VF2IC2-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100

0 commit comments

Comments
 (0)