Skip to content

Commit aca7b3e

Browse files
committed
[VectorCombine] Scalarize bin ops and cmps with two splatted operands
1 parent f2ec705 commit aca7b3e

File tree

3 files changed

+124
-81
lines changed

3 files changed

+124
-81
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 72 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,50 +1035,61 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
10351035
if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
10361036
return false;
10371037

1038-
// Match against one or both scalar values being inserted into constant
1039-
// vectors:
1040-
// vec_op VecC0, (inselt VecC1, V1, Index)
1041-
// vec_op (inselt VecC0, V0, Index), VecC1
1042-
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
1043-
// TODO: Deal with mismatched index constants and variable indexes?
10441038
Constant *VecC0 = nullptr, *VecC1 = nullptr;
10451039
Value *V0 = nullptr, *V1 = nullptr;
1046-
uint64_t Index0 = 0, Index1 = 0;
1047-
if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
1048-
m_ConstantInt(Index0))) &&
1049-
!match(Ins0, m_Constant(VecC0)))
1050-
return false;
1051-
if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
1052-
m_ConstantInt(Index1))) &&
1053-
!match(Ins1, m_Constant(VecC1)))
1054-
return false;
1040+
std::optional<uint64_t> Index;
1041+
1042+
// Try and match against two splatted operands first.
1043+
// vec_op (splat V0), (splat V1)
1044+
V0 = getSplatValue(Ins0);
1045+
V1 = getSplatValue(Ins1);
1046+
if (!V0 || !V1) {
1047+
// Match against one or both scalar values being inserted into constant
1048+
// vectors:
1049+
// vec_op VecC0, (inselt VecC1, V1, Index)
1050+
// vec_op (inselt VecC0, V0, Index), VecC1
1051+
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
1052+
// TODO: Deal with mismatched index constants and variable indexes?
1053+
V0 = nullptr, V1 = nullptr;
1054+
uint64_t Index0 = 0, Index1 = 0;
1055+
if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
1056+
m_ConstantInt(Index0))) &&
1057+
!match(Ins0, m_Constant(VecC0)))
1058+
return false;
1059+
if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
1060+
m_ConstantInt(Index1))) &&
1061+
!match(Ins1, m_Constant(VecC1)))
1062+
return false;
10551063

1056-
bool IsConst0 = !V0;
1057-
bool IsConst1 = !V1;
1058-
if (IsConst0 && IsConst1)
1059-
return false;
1060-
if (!IsConst0 && !IsConst1 && Index0 != Index1)
1061-
return false;
1064+
bool IsConst0 = !V0;
1065+
bool IsConst1 = !V1;
1066+
if (IsConst0 && IsConst1)
1067+
return false;
1068+
if (!IsConst0 && !IsConst1 && Index0 != Index1)
1069+
return false;
10621070

1063-
auto *VecTy0 = cast<VectorType>(Ins0->getType());
1064-
auto *VecTy1 = cast<VectorType>(Ins1->getType());
1065-
if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
1066-
VecTy1->getElementCount().getKnownMinValue() <= Index1)
1067-
return false;
1071+
auto *VecTy0 = cast<VectorType>(Ins0->getType());
1072+
auto *VecTy1 = cast<VectorType>(Ins1->getType());
1073+
if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
1074+
VecTy1->getElementCount().getKnownMinValue() <= Index1)
1075+
return false;
10681076

1069-
// Bail for single insertion if it is a load.
1070-
// TODO: Handle this once getVectorInstrCost can cost for load/stores.
1071-
auto *I0 = dyn_cast_or_null<Instruction>(V0);
1072-
auto *I1 = dyn_cast_or_null<Instruction>(V1);
1073-
if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
1074-
(IsConst1 && I0 && I0->mayReadFromMemory()))
1075-
return false;
1077+
// Bail for single insertion if it is a load.
1078+
// TODO: Handle this once getVectorInstrCost can cost for load/stores.
1079+
auto *I0 = dyn_cast_or_null<Instruction>(V0);
1080+
auto *I1 = dyn_cast_or_null<Instruction>(V1);
1081+
if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
1082+
(IsConst1 && I0 && I0->mayReadFromMemory()))
1083+
return false;
1084+
1085+
Index = IsConst0 ? Index1 : Index0;
1086+
}
10761087

1077-
uint64_t Index = IsConst0 ? Index1 : Index0;
1078-
Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
1079-
Type *VecTy = I.getType();
1088+
auto *VecTy = cast<VectorType>(I.getType());
1089+
Type *ScalarTy = VecTy->getElementType();
10801090
assert(VecTy->isVectorTy() &&
1081-
(IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
1091+
(isa<Constant>(Ins0) || isa<Constant>(Ins1) ||
1092+
V0->getType() == V1->getType()) &&
10821093
(ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
10831094
ScalarTy->isPointerTy()) &&
10841095
"Unexpected types for insert element into binop or cmp");
@@ -1099,29 +1110,33 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
10991110
// Get cost estimate for the insert element. This cost will factor into
11001111
// both sequences.
11011112
InstructionCost InsertCost = TTI.getVectorInstrCost(
1102-
Instruction::InsertElement, VecTy, CostKind, Index);
1103-
InstructionCost OldCost =
1104-
(IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
1105-
InstructionCost NewCost = ScalarOpCost + InsertCost +
1106-
(IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
1107-
(IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
1113+
Instruction::InsertElement, VecTy, CostKind, Index.value_or(0));
1114+
InstructionCost OldCost = (isa<Constant>(Ins0) ? 0 : InsertCost) +
1115+
(isa<Constant>(Ins1) ? 0 : InsertCost) +
1116+
VectorOpCost;
1117+
InstructionCost NewCost =
1118+
ScalarOpCost + InsertCost +
1119+
(isa<Constant>(Ins0) ? 0 : !Ins0->hasOneUse() * InsertCost) +
1120+
(isa<Constant>(Ins1) ? 0 : !Ins1->hasOneUse() * InsertCost);
11081121

11091122
// We want to scalarize unless the vector variant actually has lower cost.
11101123
if (OldCost < NewCost || !NewCost.isValid())
11111124
return false;
11121125

11131126
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
11141127
// inselt NewVecC, (scalar_op V0, V1), Index
1128+
//
1129+
// vec_op (splat V0), (splat V1) --> splat (scalar_op V0, V1)
11151130
if (IsCmp)
11161131
++NumScalarCmp;
11171132
else
11181133
++NumScalarBO;
11191134

11201135
// For constant cases, extract the scalar element, this should constant fold.
1121-
if (IsConst0)
1122-
V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
1123-
if (IsConst1)
1124-
V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
1136+
if (Index && isa<Constant>(Ins0))
1137+
V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(*Index));
1138+
if (Index && isa<Constant>(Ins1))
1139+
V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(*Index));
11251140

11261141
Value *Scalar =
11271142
IsCmp ? Builder.CreateCmp(Pred, V0, V1)
@@ -1134,12 +1149,16 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
11341149
if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
11351150
ScalarInst->copyIRFlags(&I);
11361151

1137-
// Fold the vector constants in the original vectors into a new base vector.
1138-
Value *NewVecC =
1139-
IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
1140-
: Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1141-
Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
1142-
replaceValue(I, *Insert);
1152+
Value *Result;
1153+
if (Index) {
1154+
// Fold the vector constants in the original vectors into a new base vector.
1155+
Value *NewVecC = IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
1156+
: Builder.CreateBinOp((Instruction::BinaryOps)Opcode,
1157+
VecC0, VecC1);
1158+
Result = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
1159+
} else
1160+
Result = Builder.CreateVectorSplat(VecTy->getElementCount(), Scalar);
1161+
replaceValue(I, *Result);
11431162
return true;
11441163
}
11451164

llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,24 @@ define <4 x i32> @shuf_icmp_ugt_v4i32_use(<4 x i32> %x, <4 x i32> %y, <4 x i32>
271271
; PR121110 - don't merge equivalent (but not matching) predicates
272272

273273
define <2 x i1> @PR121110() {
274-
; CHECK-LABEL: define <2 x i1> @PR121110(
275-
; CHECK-SAME: ) #[[ATTR0]] {
276-
; CHECK-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
277-
; CHECK-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
278-
; CHECK-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[UGT]], <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
279-
; CHECK-NEXT: ret <2 x i1> [[RES]]
274+
; SSE-LABEL: define <2 x i1> @PR121110(
275+
; SSE-SAME: ) #[[ATTR0]] {
276+
; SSE-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
277+
; SSE-NEXT: [[RES:%.*]] = shufflevector <2 x i1> zeroinitializer, <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
278+
; SSE-NEXT: ret <2 x i1> [[RES]]
279+
;
280+
; AVX2-LABEL: define <2 x i1> @PR121110(
281+
; AVX2-SAME: ) #[[ATTR0]] {
282+
; AVX2-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
283+
; AVX2-NEXT: [[RES:%.*]] = shufflevector <2 x i1> zeroinitializer, <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
284+
; AVX2-NEXT: ret <2 x i1> [[RES]]
285+
;
286+
; AVX512-LABEL: define <2 x i1> @PR121110(
287+
; AVX512-SAME: ) #[[ATTR0]] {
288+
; AVX512-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
289+
; AVX512-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
290+
; AVX512-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[UGT]], <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
291+
; AVX512-NEXT: ret <2 x i1> [[RES]]
280292
;
281293
%ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >
282294
%sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
@@ -285,12 +297,24 @@ define <2 x i1> @PR121110() {
285297
}
286298

287299
define <2 x i1> @PR121110_commute() {
288-
; CHECK-LABEL: define <2 x i1> @PR121110_commute(
289-
; CHECK-SAME: ) #[[ATTR0]] {
290-
; CHECK-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
291-
; CHECK-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
292-
; CHECK-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> [[UGT]], <2 x i32> <i32 0, i32 3>
293-
; CHECK-NEXT: ret <2 x i1> [[RES]]
300+
; SSE-LABEL: define <2 x i1> @PR121110_commute(
301+
; SSE-SAME: ) #[[ATTR0]] {
302+
; SSE-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
303+
; SSE-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> zeroinitializer, <2 x i32> <i32 0, i32 3>
304+
; SSE-NEXT: ret <2 x i1> [[RES]]
305+
;
306+
; AVX2-LABEL: define <2 x i1> @PR121110_commute(
307+
; AVX2-SAME: ) #[[ATTR0]] {
308+
; AVX2-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
309+
; AVX2-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> zeroinitializer, <2 x i32> <i32 0, i32 3>
310+
; AVX2-NEXT: ret <2 x i1> [[RES]]
311+
;
312+
; AVX512-LABEL: define <2 x i1> @PR121110_commute(
313+
; AVX512-SAME: ) #[[ATTR0]] {
314+
; AVX512-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
315+
; AVX512-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
316+
; AVX512-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> [[UGT]], <2 x i32> <i32 0, i32 3>
317+
; AVX512-NEXT: ret <2 x i1> [[RES]]
294318
;
295319
%sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
296320
%ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >

llvm/test/Transforms/VectorCombine/scalarize-binop.ll

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
define <4 x i32> @add_v4i32(i32 %x, i32 %y) {
55
; CHECK-LABEL: define <4 x i32> @add_v4i32(
66
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
7-
; CHECK-NEXT: [[X_HEAD:%.*]] = insertelement <4 x i32> poison, i32 [[X]], i32 0
7+
; CHECK-NEXT: [[RES_SCALAR:%.*]] = add i32 [[X]], 42
8+
; CHECK-NEXT: [[X_HEAD:%.*]] = insertelement <4 x i32> poison, i32 [[RES_SCALAR]], i64 0
89
; CHECK-NEXT: [[X_SPLAT:%.*]] = shufflevector <4 x i32> [[X_HEAD]], <4 x i32> poison, <4 x i32> zeroinitializer
9-
; CHECK-NEXT: [[RES:%.*]] = add <4 x i32> [[X_SPLAT]], splat (i32 42)
10-
; CHECK-NEXT: ret <4 x i32> [[RES]]
10+
; CHECK-NEXT: ret <4 x i32> [[X_SPLAT]]
1111
;
1212
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
1313
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer
@@ -18,10 +18,10 @@ define <4 x i32> @add_v4i32(i32 %x, i32 %y) {
1818
define <vscale x 4 x i32> @add_nxv4i32(i32 %x, i32 %y) {
1919
; CHECK-LABEL: define <vscale x 4 x i32> @add_nxv4i32(
2020
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
21-
; CHECK-NEXT: [[Y_HEAD1:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[X]], i32 0
21+
; CHECK-NEXT: [[RES_SCALAR:%.*]] = add i32 [[X]], 42
22+
; CHECK-NEXT: [[Y_HEAD1:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[RES_SCALAR]], i64 0
2223
; CHECK-NEXT: [[Y_SPLAT1:%.*]] = shufflevector <vscale x 4 x i32> [[Y_HEAD1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
23-
; CHECK-NEXT: [[RES:%.*]] = add <vscale x 4 x i32> [[Y_SPLAT1]], splat (i32 42)
24-
; CHECK-NEXT: ret <vscale x 4 x i32> [[RES]]
24+
; CHECK-NEXT: ret <vscale x 4 x i32> [[Y_SPLAT1]]
2525
;
2626
%x.head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
2727
%x.splat = shufflevector <vscale x 4 x i32> %x.head, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
@@ -33,11 +33,11 @@ define <vscale x 4 x i32> @add_nxv4i32(i32 %x, i32 %y) {
3333
define <4 x i32> @add_mul_v4i32(i32 %x, i32 %y, i32 %z) {
3434
; CHECK-LABEL: define <4 x i32> @add_mul_v4i32(
3535
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
36-
; CHECK-NEXT: [[Z_HEAD1:%.*]] = insertelement <4 x i32> poison, i32 [[X]], i32 0
36+
; CHECK-NEXT: [[RES0_SCALAR:%.*]] = add i32 [[X]], 42
37+
; CHECK-NEXT: [[RES1_SCALAR:%.*]] = mul i32 [[RES0_SCALAR]], 42
38+
; CHECK-NEXT: [[Z_HEAD1:%.*]] = insertelement <4 x i32> poison, i32 [[RES1_SCALAR]], i64 0
3739
; CHECK-NEXT: [[Z_SPLAT1:%.*]] = shufflevector <4 x i32> [[Z_HEAD1]], <4 x i32> poison, <4 x i32> zeroinitializer
38-
; CHECK-NEXT: [[RES0:%.*]] = add <4 x i32> [[Z_SPLAT1]], splat (i32 42)
39-
; CHECK-NEXT: [[RES1:%.*]] = mul <4 x i32> [[RES0]], splat (i32 42)
40-
; CHECK-NEXT: ret <4 x i32> [[RES1]]
40+
; CHECK-NEXT: ret <4 x i32> [[Z_SPLAT1]]
4141
;
4242
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
4343
%x.splat = shufflevector <4 x i32> %x.head, <4 x i32> poison, <4 x i32> zeroinitializer
@@ -68,9 +68,9 @@ define <4 x i32> @other_users_v4i32(i32 %x, i32 %y, ptr %p, ptr %q) {
6868
define <4 x i1> @icmp_v4i32(i32 %x, i32 %y) {
6969
; CHECK-LABEL: define <4 x i1> @icmp_v4i32(
7070
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
71-
; CHECK-NEXT: [[X_HEAD:%.*]] = insertelement <4 x i32> poison, i32 [[X]], i32 0
72-
; CHECK-NEXT: [[X_SPLAT:%.*]] = shufflevector <4 x i32> [[X_HEAD]], <4 x i32> poison, <4 x i32> zeroinitializer
73-
; CHECK-NEXT: [[RES:%.*]] = icmp eq <4 x i32> [[X_SPLAT]], splat (i32 42)
71+
; CHECK-NEXT: [[RES_SCALAR:%.*]] = icmp eq i32 [[X]], 42
72+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <4 x i1> poison, i1 [[RES_SCALAR]], i64 0
73+
; CHECK-NEXT: [[RES:%.*]] = shufflevector <4 x i1> [[DOTSPLATINSERT]], <4 x i1> poison, <4 x i32> zeroinitializer
7474
; CHECK-NEXT: ret <4 x i1> [[RES]]
7575
;
7676
%x.head = insertelement <4 x i32> poison, i32 %x, i32 0
@@ -82,9 +82,9 @@ define <4 x i1> @icmp_v4i32(i32 %x, i32 %y) {
8282
define <vscale x 4 x i1> @icmp_nxv4i32(i32 %x, i32 %y) {
8383
; CHECK-LABEL: define <vscale x 4 x i1> @icmp_nxv4i32(
8484
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
85-
; CHECK-NEXT: [[X_HEAD:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[X]], i32 0
86-
; CHECK-NEXT: [[X_SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[X_HEAD]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
87-
; CHECK-NEXT: [[RES:%.*]] = icmp eq <vscale x 4 x i32> [[X_SPLAT]], splat (i32 42)
85+
; CHECK-NEXT: [[RES_SCALAR:%.*]] = icmp eq i32 [[X]], 42
86+
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 4 x i1> poison, i1 [[RES_SCALAR]], i64 0
87+
; CHECK-NEXT: [[RES:%.*]] = shufflevector <vscale x 4 x i1> [[DOTSPLATINSERT]], <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
8888
; CHECK-NEXT: ret <vscale x 4 x i1> [[RES]]
8989
;
9090
%x.head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0

0 commit comments

Comments
 (0)