Skip to content

Commit f9f7969

Browse files
committed
Reassociate header mask
1 parent 4c91627 commit f9f7969

File tree

4 files changed

+87
-81
lines changed

4 files changed

+87
-81
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,8 @@ static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode,
996996
}
997997

998998
/// Try to simplify recipe \p R.
999-
static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
999+
static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo,
1000+
VPValue *HeaderMask) {
10001001
VPlan *Plan = R.getParent()->getPlan();
10011002

10021003
auto *Def = dyn_cast<VPSingleDefRecipe>(&R);
@@ -1119,6 +1120,14 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
11191120
return;
11201121
}
11211122

1123+
// Reassociate the header mask so it has more opportunities to be simplified.
1124+
// (headermask && x) && y -> headermask && (x && y)
1125+
if (HeaderMask && match(Def, m_LogicalAnd(m_LogicalAnd(m_Specific(HeaderMask),
1126+
m_VPValue(X)),
1127+
m_VPValue(Y))))
1128+
return Def->replaceAllUsesWith(
1129+
Builder.createLogicalAnd(HeaderMask, Builder.createLogicalAnd(X, Y)));
1130+
11221131
if (match(Def, m_c_Mul(m_VPValue(A), m_SpecificInt(1))))
11231132
return Def->replaceAllUsesWith(A);
11241133

@@ -1263,13 +1272,61 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
12631272
}
12641273
}
12651274

1275+
/// Collect the header mask with the pattern:
1276+
/// (ICMP_ULE, WideCanonicalIV, backedge-taken-count)
1277+
/// TODO: Introduce explicit recipe for header-mask instead of searching
1278+
/// for the header-mask pattern manually.
1279+
static VPSingleDefRecipe *findHeaderMask(VPlan &Plan) {
1280+
SmallVector<VPValue *> WideCanonicalIVs;
1281+
auto *FoundWidenCanonicalIVUser = find_if(Plan.getCanonicalIV()->users(),
1282+
IsaPred<VPWidenCanonicalIVRecipe>);
1283+
assert(count_if(Plan.getCanonicalIV()->users(),
1284+
IsaPred<VPWidenCanonicalIVRecipe>) <= 1 &&
1285+
"Must have at most one VPWideCanonicalIVRecipe");
1286+
if (FoundWidenCanonicalIVUser != Plan.getCanonicalIV()->users().end()) {
1287+
auto *WideCanonicalIV =
1288+
cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
1289+
WideCanonicalIVs.push_back(WideCanonicalIV);
1290+
}
1291+
1292+
// Also include VPWidenIntOrFpInductionRecipes that represent a widened
1293+
// version of the canonical induction.
1294+
VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
1295+
for (VPRecipeBase &Phi : HeaderVPBB->phis()) {
1296+
auto *WidenOriginalIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi);
1297+
if (WidenOriginalIV && WidenOriginalIV->isCanonical())
1298+
WideCanonicalIVs.push_back(WidenOriginalIV);
1299+
}
1300+
1301+
// Walk users of wide canonical IVs and find the single compare of the form
1302+
// (ICMP_ULE, WideCanonicalIV, backedge-taken-count).
1303+
VPSingleDefRecipe *HeaderMask = nullptr;
1304+
for (auto *Wide : WideCanonicalIVs) {
1305+
for (VPUser *U : SmallVector<VPUser *>(Wide->users())) {
1306+
auto *VPI = dyn_cast<VPInstruction>(U);
1307+
if (!VPI || !vputils::isHeaderMask(VPI, Plan))
1308+
continue;
1309+
1310+
assert(VPI->getOperand(0) == Wide &&
1311+
"WidenCanonicalIV must be the first operand of the compare");
1312+
assert(!HeaderMask && "Multiple header masks found?");
1313+
HeaderMask = VPI;
1314+
}
1315+
}
1316+
return HeaderMask;
1317+
}
1318+
12661319
void VPlanTransforms::simplifyRecipes(VPlan &Plan) {
1320+
VPValue *HeaderMask = nullptr;
1321+
// Skip the header-mask lookup once the plan is unrolled, as there can then be multiple header masks.
1322+
if (!Plan.isUnrolled())
1323+
HeaderMask = findHeaderMask(Plan);
12671324
ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT(
12681325
Plan.getEntry());
12691326
VPTypeAnalysis TypeInfo(Plan);
12701327
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
12711328
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
1272-
simplifyRecipe(R, TypeInfo);
1329+
simplifyRecipe(R, TypeInfo, HeaderMask);
12731330
}
12741331
}
12751332
}
@@ -2192,50 +2249,6 @@ static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch(
21922249
return LaneMaskPhi;
21932250
}
21942251

2195-
/// Collect the header mask with the pattern:
2196-
/// (ICMP_ULE, WideCanonicalIV, backedge-taken-count)
2197-
/// TODO: Introduce explicit recipe for header-mask instead of searching
2198-
/// for the header-mask pattern manually.
2199-
static VPSingleDefRecipe *findHeaderMask(VPlan &Plan) {
2200-
SmallVector<VPValue *> WideCanonicalIVs;
2201-
auto *FoundWidenCanonicalIVUser = find_if(Plan.getCanonicalIV()->users(),
2202-
IsaPred<VPWidenCanonicalIVRecipe>);
2203-
assert(count_if(Plan.getCanonicalIV()->users(),
2204-
IsaPred<VPWidenCanonicalIVRecipe>) <= 1 &&
2205-
"Must have at most one VPWideCanonicalIVRecipe");
2206-
if (FoundWidenCanonicalIVUser != Plan.getCanonicalIV()->users().end()) {
2207-
auto *WideCanonicalIV =
2208-
cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
2209-
WideCanonicalIVs.push_back(WideCanonicalIV);
2210-
}
2211-
2212-
// Also include VPWidenIntOrFpInductionRecipes that represent a widened
2213-
// version of the canonical induction.
2214-
VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
2215-
for (VPRecipeBase &Phi : HeaderVPBB->phis()) {
2216-
auto *WidenOriginalIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi);
2217-
if (WidenOriginalIV && WidenOriginalIV->isCanonical())
2218-
WideCanonicalIVs.push_back(WidenOriginalIV);
2219-
}
2220-
2221-
// Walk users of wide canonical IVs and find the single compare of the form
2222-
// (ICMP_ULE, WideCanonicalIV, backedge-taken-count).
2223-
VPSingleDefRecipe *HeaderMask = nullptr;
2224-
for (auto *Wide : WideCanonicalIVs) {
2225-
for (VPUser *U : SmallVector<VPUser *>(Wide->users())) {
2226-
auto *VPI = dyn_cast<VPInstruction>(U);
2227-
if (!VPI || !vputils::isHeaderMask(VPI, Plan))
2228-
continue;
2229-
2230-
assert(VPI->getOperand(0) == Wide &&
2231-
"WidenCanonicalIV must be the first operand of the compare");
2232-
assert(!HeaderMask && "Multiple header masks found?");
2233-
HeaderMask = VPI;
2234-
}
2235-
}
2236-
return HeaderMask;
2237-
}
2238-
22392252
void VPlanTransforms::addActiveLaneMask(
22402253
VPlan &Plan, bool UseActiveLaneMaskForControlFlow,
22412254
bool DataAndControlFlowWithoutRuntimeCheck) {

llvm/test/Transforms/LoopVectorize/RISCV/blocks-with-dead-instructions.ll

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -436,23 +436,17 @@ define void @multiple_blocks_with_dead_inst_multiple_successors_6(ptr %src, i1 %
436436
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <vscale x 8 x i64> [ [[INDUCTION]], %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ]
437437
; CHECK-NEXT: [[AVL:%.*]] = phi i64 [ [[TMP2]], %[[VECTOR_PH]] ], [ [[AVL_NEXT:%.*]], %[[VECTOR_BODY]] ]
438438
; CHECK-NEXT: [[TMP27:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 8, i1 true)
439-
; CHECK-NEXT: [[BROADCAST_SPLATINSERT3:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TMP27]], i64 0
440-
; CHECK-NEXT: [[BROADCAST_SPLAT4:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT3]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
441439
; CHECK-NEXT: [[TMP12:%.*]] = zext i32 [[TMP27]] to i64
442440
; CHECK-NEXT: [[TMP16:%.*]] = mul i64 3, [[TMP12]]
443441
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 8 x i64> poison, i64 [[TMP16]], i64 0
444442
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 8 x i64> [[DOTSPLATINSERT]], <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
445-
; CHECK-NEXT: [[TMP14:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
446-
; CHECK-NEXT: [[TMP15:%.*]] = icmp ult <vscale x 8 x i32> [[TMP14]], [[BROADCAST_SPLAT4]]
447443
; CHECK-NEXT: [[TMP20:%.*]] = getelementptr i16, ptr [[SRC]], <vscale x 8 x i64> [[VEC_IND]]
448444
; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 8 x i16> @llvm.vp.gather.nxv8i16.nxv8p0(<vscale x 8 x ptr> align 2 [[TMP20]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP27]])
449445
; CHECK-NEXT: [[TMP17:%.*]] = icmp eq <vscale x 8 x i16> [[WIDE_MASKED_GATHER]], zeroinitializer
450-
; CHECK-NEXT: [[TMP18:%.*]] = select <vscale x 8 x i1> [[TMP15]], <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i1> zeroinitializer
451-
; CHECK-NEXT: [[TMP19:%.*]] = select <vscale x 8 x i1> [[TMP18]], <vscale x 8 x i1> [[TMP8]], <vscale x 8 x i1> zeroinitializer
446+
; CHECK-NEXT: [[TMP29:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i1> [[TMP8]], <vscale x 8 x i1> zeroinitializer
452447
; CHECK-NEXT: [[TMP28:%.*]] = xor <vscale x 8 x i1> [[TMP17]], splat (i1 true)
453-
; CHECK-NEXT: [[TMP21:%.*]] = select <vscale x 8 x i1> [[TMP15]], <vscale x 8 x i1> [[TMP28]], <vscale x 8 x i1> zeroinitializer
454-
; CHECK-NEXT: [[TMP22:%.*]] = or <vscale x 8 x i1> [[TMP19]], [[TMP21]]
455-
; CHECK-NEXT: [[TMP23:%.*]] = select <vscale x 8 x i1> [[TMP18]], <vscale x 8 x i1> [[BROADCAST_SPLAT]], <vscale x 8 x i1> zeroinitializer
448+
; CHECK-NEXT: [[TMP22:%.*]] = or <vscale x 8 x i1> [[TMP29]], [[TMP28]]
449+
; CHECK-NEXT: [[TMP23:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i1> [[BROADCAST_SPLAT]], <vscale x 8 x i1> zeroinitializer
456450
; CHECK-NEXT: [[TMP24:%.*]] = or <vscale x 8 x i1> [[TMP22]], [[TMP23]]
457451
; CHECK-NEXT: call void @llvm.vp.scatter.nxv8i16.nxv8p0(<vscale x 8 x i16> zeroinitializer, <vscale x 8 x ptr> align 2 [[TMP20]], <vscale x 8 x i1> [[TMP24]], i32 [[TMP27]])
458452
; CHECK-NEXT: [[TMP25:%.*]] = zext i32 [[TMP27]] to i64

llvm/test/Transforms/LoopVectorize/RISCV/pr87378-vpinstruction-or-drop-poison-generating-flags.ll

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,16 @@ define void @pr87378_vpinstruction_or_drop_poison_generating_flags(ptr %arg, i64
3434
; CHECK-NEXT: [[TMP10:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
3535
; CHECK-NEXT: [[TMP11:%.*]] = icmp ult <vscale x 8 x i32> [[TMP10]], [[BROADCAST_SPLAT8]]
3636
; CHECK-NEXT: [[TMP13:%.*]] = icmp ule <vscale x 8 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
37-
; CHECK-NEXT: [[TMP28:%.*]] = select <vscale x 8 x i1> [[TMP11]], <vscale x 8 x i1> [[TMP13]], <vscale x 8 x i1> zeroinitializer
3837
; CHECK-NEXT: [[TMP14:%.*]] = icmp ule <vscale x 8 x i64> [[VEC_IND]], [[BROADCAST_SPLAT2]]
39-
; CHECK-NEXT: [[TMP15:%.*]] = select <vscale x 8 x i1> [[TMP28]], <vscale x 8 x i1> [[TMP14]], <vscale x 8 x i1> zeroinitializer
38+
; CHECK-NEXT: [[TMP9:%.*]] = select <vscale x 8 x i1> [[TMP13]], <vscale x 8 x i1> [[TMP14]], <vscale x 8 x i1> zeroinitializer
4039
; CHECK-NEXT: [[TMP16:%.*]] = xor <vscale x 8 x i1> [[TMP13]], splat (i1 true)
41-
; CHECK-NEXT: [[TMP29:%.*]] = select <vscale x 8 x i1> [[TMP11]], <vscale x 8 x i1> [[TMP16]], <vscale x 8 x i1> zeroinitializer
42-
; CHECK-NEXT: [[TMP17:%.*]] = or <vscale x 8 x i1> [[TMP15]], [[TMP29]]
40+
; CHECK-NEXT: [[TMP17:%.*]] = or <vscale x 8 x i1> [[TMP9]], [[TMP16]]
4341
; CHECK-NEXT: [[TMP18:%.*]] = icmp ule <vscale x 8 x i64> [[VEC_IND]], [[BROADCAST_SPLAT4]]
4442
; CHECK-NEXT: [[TMP19:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i1> [[TMP18]], <vscale x 8 x i1> zeroinitializer
4543
; CHECK-NEXT: [[TMP20:%.*]] = xor <vscale x 8 x i1> [[TMP14]], splat (i1 true)
46-
; CHECK-NEXT: [[TMP21:%.*]] = select <vscale x 8 x i1> [[TMP28]], <vscale x 8 x i1> [[TMP20]], <vscale x 8 x i1> zeroinitializer
47-
; CHECK-NEXT: [[TMP22:%.*]] = or <vscale x 8 x i1> [[TMP19]], [[TMP21]]
44+
; CHECK-NEXT: [[TMP28:%.*]] = select <vscale x 8 x i1> [[TMP13]], <vscale x 8 x i1> [[TMP20]], <vscale x 8 x i1> zeroinitializer
45+
; CHECK-NEXT: [[TMP21:%.*]] = select <vscale x 8 x i1> [[TMP11]], <vscale x 8 x i1> [[TMP28]], <vscale x 8 x i1> zeroinitializer
46+
; CHECK-NEXT: [[TMP22:%.*]] = or <vscale x 8 x i1> [[TMP19]], [[TMP28]]
4847
; CHECK-NEXT: [[TMP23:%.*]] = extractelement <vscale x 8 x i1> [[TMP21]], i32 0
4948
; CHECK-NEXT: [[PREDPHI:%.*]] = select i1 [[TMP23]], i64 poison, i64 [[INDEX]]
5049
; CHECK-NEXT: [[TMP24:%.*]] = getelementptr i16, ptr [[ARG]], i64 [[PREDPHI]]

0 commit comments

Comments
 (0)