Skip to content

Commit 2b7b8bd

Browse files
authored
[X86] Accept the canonical form of a sign bit test in MatchVectorAllEqualTest. (#154421)
This function looks for the pattern (seteq (and (reduce_or), mask), 0). If the mask is the sign bit, InstCombine will have canonicalized the pattern to (setgt (reduce_or), -1), so we should handle that form too. I'm looking into adding the same canonicalization to SimplifySetCC, and this change is needed to prevent test regressions.
1 parent 562e021 commit 2b7b8bd

File tree

2 files changed

+145
-46
lines changed

2 files changed

+145
-46
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23185,43 +23185,51 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
2318523185

2318623186
// Check whether an AND/OR'd reduction tree is PTEST-able, or if we can fallback
2318723187
// to CMP(MOVMSK(PCMPEQB(X,Y))).
23188-
static SDValue MatchVectorAllEqualTest(SDValue LHS, SDValue RHS,
23188+
static SDValue MatchVectorAllEqualTest(SDValue OrigLHS, SDValue OrigRHS,
2318923189
ISD::CondCode CC, const SDLoc &DL,
2319023190
const X86Subtarget &Subtarget,
2319123191
SelectionDAG &DAG,
2319223192
X86::CondCode &X86CC) {
23193-
assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
23193+
SDValue Op = OrigLHS;
2319423194

23195-
bool CmpNull = isNullConstant(RHS);
23196-
bool CmpAllOnes = isAllOnesConstant(RHS);
23197-
if (!CmpNull && !CmpAllOnes)
23198-
return SDValue();
23195+
bool CmpNull;
23196+
APInt Mask;
23197+
if (CC == ISD::SETEQ || CC == ISD::SETNE) {
23198+
CmpNull = isNullConstant(OrigRHS);
23199+
if (!CmpNull && !isAllOnesConstant(OrigRHS))
23200+
return SDValue();
2319923201

23200-
SDValue Op = LHS;
23201-
if (!Subtarget.hasSSE2() || !Op->hasOneUse())
23202-
return SDValue();
23202+
if (!Subtarget.hasSSE2() || !Op->hasOneUse())
23203+
return SDValue();
2320323204

23204-
// Check whether we're masking/truncating an OR-reduction result, in which
23205-
// case track the masked bits.
23206-
// TODO: Add CmpAllOnes support.
23207-
APInt Mask = APInt::getAllOnes(Op.getScalarValueSizeInBits());
23208-
if (CmpNull) {
23209-
switch (Op.getOpcode()) {
23210-
case ISD::TRUNCATE: {
23211-
SDValue Src = Op.getOperand(0);
23212-
Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
23213-
Op.getScalarValueSizeInBits());
23214-
Op = Src;
23215-
break;
23216-
}
23217-
case ISD::AND: {
23218-
if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
23219-
Mask = Cst->getAPIntValue();
23220-
Op = Op.getOperand(0);
23205+
// Check whether we're masking/truncating an OR-reduction result, in which
23206+
// case track the masked bits.
23207+
// TODO: Add CmpAllOnes support.
23208+
Mask = APInt::getAllOnes(Op.getScalarValueSizeInBits());
23209+
if (CmpNull) {
23210+
switch (Op.getOpcode()) {
23211+
case ISD::TRUNCATE: {
23212+
SDValue Src = Op.getOperand(0);
23213+
Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
23214+
Op.getScalarValueSizeInBits());
23215+
Op = Src;
23216+
break;
23217+
}
23218+
case ISD::AND: {
23219+
if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
23220+
Mask = Cst->getAPIntValue();
23221+
Op = Op.getOperand(0);
23222+
}
23223+
break;
23224+
}
2322123225
}
23222-
break;
23223-
}
2322423226
}
23227+
} else if (CC == ISD::SETGT && isAllOnesConstant(OrigRHS)) {
23228+
CC = ISD::SETEQ;
23229+
CmpNull = true;
23230+
Mask = APInt::getSignMask(Op.getScalarValueSizeInBits());
23231+
} else {
23232+
return SDValue();
2322523233
}
2322623234

2322723235
ISD::NodeType LogicOp = CmpNull ? ISD::OR : ISD::AND;
@@ -56274,14 +56282,16 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5627456282
if (SDValue V = combineVectorSizedSetCCEquality(VT, LHS, RHS, CC, DL, DAG,
5627556283
Subtarget))
5627656284
return V;
56285+
}
5627756286

56278-
if (VT == MVT::i1) {
56279-
X86::CondCode X86CC;
56280-
if (SDValue V =
56281-
MatchVectorAllEqualTest(LHS, RHS, CC, DL, Subtarget, DAG, X86CC))
56282-
return DAG.getNode(ISD::TRUNCATE, DL, VT, getSETCC(X86CC, V, DL, DAG));
56283-
}
56287+
if (VT == MVT::i1) {
56288+
X86::CondCode X86CC;
56289+
if (SDValue V =
56290+
MatchVectorAllEqualTest(LHS, RHS, CC, DL, Subtarget, DAG, X86CC))
56291+
return DAG.getNode(ISD::TRUNCATE, DL, VT, getSETCC(X86CC, V, DL, DAG));
56292+
}
5628456293

56294+
if (CC == ISD::SETNE || CC == ISD::SETEQ) {
5628556295
if (OpVT.isScalarInteger()) {
5628656296
// cmpeq(or(X,Y),X) --> cmpeq(and(~X,Y),0)
5628756297
// cmpne(or(X,Y),X) --> cmpne(and(~X,Y),0)

llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,95 @@ define i1 @mask_v8i32(<8 x i32> %a0) {
903903
ret i1 %3
904904
}
905905

906+
define i1 @mask_v8i32_2(<8 x i32> %a0) {
907+
; SSE2-LABEL: mask_v8i32_2:
908+
; SSE2: # %bb.0:
909+
; SSE2-NEXT: por %xmm1, %xmm0
910+
; SSE2-NEXT: pslld $1, %xmm0
911+
; SSE2-NEXT: movmskps %xmm0, %eax
912+
; SSE2-NEXT: testl %eax, %eax
913+
; SSE2-NEXT: sete %al
914+
; SSE2-NEXT: retq
915+
;
916+
; SSE41-LABEL: mask_v8i32_2:
917+
; SSE41: # %bb.0:
918+
; SSE41-NEXT: por %xmm1, %xmm0
919+
; SSE41-NEXT: ptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
920+
; SSE41-NEXT: sete %al
921+
; SSE41-NEXT: retq
922+
;
923+
; AVX1-LABEL: mask_v8i32_2:
924+
; AVX1: # %bb.0:
925+
; AVX1-NEXT: vptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0
926+
; AVX1-NEXT: sete %al
927+
; AVX1-NEXT: vzeroupper
928+
; AVX1-NEXT: retq
929+
;
930+
; AVX2-LABEL: mask_v8i32_2:
931+
; AVX2: # %bb.0:
932+
; AVX2-NEXT: vpbroadcastq {{.*#+}} ymm1 = [4611686019501129728,4611686019501129728,4611686019501129728,4611686019501129728]
933+
; AVX2-NEXT: vptest %ymm1, %ymm0
934+
; AVX2-NEXT: sete %al
935+
; AVX2-NEXT: vzeroupper
936+
; AVX2-NEXT: retq
937+
;
938+
; AVX512-LABEL: mask_v8i32_2:
939+
; AVX512: # %bb.0:
940+
; AVX512-NEXT: vpbroadcastq {{.*#+}} ymm1 = [4611686019501129728,4611686019501129728,4611686019501129728,4611686019501129728]
941+
; AVX512-NEXT: vptest %ymm1, %ymm0
942+
; AVX512-NEXT: sete %al
943+
; AVX512-NEXT: vzeroupper
944+
; AVX512-NEXT: retq
945+
%1 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
946+
%2 = and i32 %1, 1073741824
947+
%3 = icmp eq i32 %2, 0
948+
ret i1 %3
949+
}
950+
951+
952+
define i1 @signtest_v8i32(<8 x i32> %a0) {
953+
; SSE2-LABEL: signtest_v8i32:
954+
; SSE2: # %bb.0:
955+
; SSE2-NEXT: orps %xmm1, %xmm0
956+
; SSE2-NEXT: movmskps %xmm0, %eax
957+
; SSE2-NEXT: testl %eax, %eax
958+
; SSE2-NEXT: sete %al
959+
; SSE2-NEXT: retq
960+
;
961+
; SSE41-LABEL: signtest_v8i32:
962+
; SSE41: # %bb.0:
963+
; SSE41-NEXT: por %xmm1, %xmm0
964+
; SSE41-NEXT: ptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
965+
; SSE41-NEXT: sete %al
966+
; SSE41-NEXT: retq
967+
;
968+
; AVX1-LABEL: signtest_v8i32:
969+
; AVX1: # %bb.0:
970+
; AVX1-NEXT: vptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0
971+
; AVX1-NEXT: sete %al
972+
; AVX1-NEXT: vzeroupper
973+
; AVX1-NEXT: retq
974+
;
975+
; AVX2-LABEL: signtest_v8i32:
976+
; AVX2: # %bb.0:
977+
; AVX2-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
978+
; AVX2-NEXT: vptest %ymm1, %ymm0
979+
; AVX2-NEXT: sete %al
980+
; AVX2-NEXT: vzeroupper
981+
; AVX2-NEXT: retq
982+
;
983+
; AVX512-LABEL: signtest_v8i32:
984+
; AVX512: # %bb.0:
985+
; AVX512-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
986+
; AVX512-NEXT: vptest %ymm1, %ymm0
987+
; AVX512-NEXT: sete %al
988+
; AVX512-NEXT: vzeroupper
989+
; AVX512-NEXT: retq
990+
%1 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
991+
%2 = icmp sgt i32 %1, -1
992+
ret i1 %2
993+
}
994+
906995
define i1 @trunc_v16i16(<16 x i16> %a0) {
907996
; SSE2-LABEL: trunc_v16i16:
908997
; SSE2: # %bb.0:
@@ -1073,11 +1162,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
10731162
; SSE2-NEXT: movd %xmm0, %eax
10741163
; SSE2-NEXT: orl %ecx, %eax
10751164
; SSE2-NEXT: testb $1, %al
1076-
; SSE2-NEXT: je .LBB27_2
1165+
; SSE2-NEXT: je .LBB29_2
10771166
; SSE2-NEXT: # %bb.1:
10781167
; SSE2-NEXT: xorl %eax, %eax
10791168
; SSE2-NEXT: retq
1080-
; SSE2-NEXT: .LBB27_2:
1169+
; SSE2-NEXT: .LBB29_2:
10811170
; SSE2-NEXT: movl $1, %eax
10821171
; SSE2-NEXT: retq
10831172
;
@@ -1092,11 +1181,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
10921181
; SSE41-NEXT: pextrd $2, %xmm1, %eax
10931182
; SSE41-NEXT: orl %ecx, %eax
10941183
; SSE41-NEXT: testb $1, %al
1095-
; SSE41-NEXT: je .LBB27_2
1184+
; SSE41-NEXT: je .LBB29_2
10961185
; SSE41-NEXT: # %bb.1:
10971186
; SSE41-NEXT: xorl %eax, %eax
10981187
; SSE41-NEXT: retq
1099-
; SSE41-NEXT: .LBB27_2:
1188+
; SSE41-NEXT: .LBB29_2:
11001189
; SSE41-NEXT: movl $1, %eax
11011190
; SSE41-NEXT: retq
11021191
;
@@ -1111,11 +1200,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
11111200
; AVX1OR2-NEXT: vpextrd $2, %xmm0, %eax
11121201
; AVX1OR2-NEXT: orl %ecx, %eax
11131202
; AVX1OR2-NEXT: testb $1, %al
1114-
; AVX1OR2-NEXT: je .LBB27_2
1203+
; AVX1OR2-NEXT: je .LBB29_2
11151204
; AVX1OR2-NEXT: # %bb.1:
11161205
; AVX1OR2-NEXT: xorl %eax, %eax
11171206
; AVX1OR2-NEXT: retq
1118-
; AVX1OR2-NEXT: .LBB27_2:
1207+
; AVX1OR2-NEXT: .LBB29_2:
11191208
; AVX1OR2-NEXT: movl $1, %eax
11201209
; AVX1OR2-NEXT: retq
11211210
;
@@ -1130,12 +1219,12 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
11301219
; AVX512F-NEXT: korw %k0, %k1, %k0
11311220
; AVX512F-NEXT: kmovw %k0, %eax
11321221
; AVX512F-NEXT: testb $1, %al
1133-
; AVX512F-NEXT: je .LBB27_2
1222+
; AVX512F-NEXT: je .LBB29_2
11341223
; AVX512F-NEXT: # %bb.1:
11351224
; AVX512F-NEXT: xorl %eax, %eax
11361225
; AVX512F-NEXT: vzeroupper
11371226
; AVX512F-NEXT: retq
1138-
; AVX512F-NEXT: .LBB27_2:
1227+
; AVX512F-NEXT: .LBB29_2:
11391228
; AVX512F-NEXT: movl $1, %eax
11401229
; AVX512F-NEXT: vzeroupper
11411230
; AVX512F-NEXT: retq
@@ -1151,12 +1240,12 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
11511240
; AVX512BW-NEXT: korw %k0, %k1, %k0
11521241
; AVX512BW-NEXT: kmovd %k0, %eax
11531242
; AVX512BW-NEXT: testb $1, %al
1154-
; AVX512BW-NEXT: je .LBB27_2
1243+
; AVX512BW-NEXT: je .LBB29_2
11551244
; AVX512BW-NEXT: # %bb.1:
11561245
; AVX512BW-NEXT: xorl %eax, %eax
11571246
; AVX512BW-NEXT: vzeroupper
11581247
; AVX512BW-NEXT: retq
1159-
; AVX512BW-NEXT: .LBB27_2:
1248+
; AVX512BW-NEXT: .LBB29_2:
11601249
; AVX512BW-NEXT: movl $1, %eax
11611250
; AVX512BW-NEXT: vzeroupper
11621251
; AVX512BW-NEXT: retq
@@ -1170,11 +1259,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
11701259
; AVX512BWVL-NEXT: korw %k0, %k1, %k0
11711260
; AVX512BWVL-NEXT: kmovd %k0, %eax
11721261
; AVX512BWVL-NEXT: testb $1, %al
1173-
; AVX512BWVL-NEXT: je .LBB27_2
1262+
; AVX512BWVL-NEXT: je .LBB29_2
11741263
; AVX512BWVL-NEXT: # %bb.1:
11751264
; AVX512BWVL-NEXT: xorl %eax, %eax
11761265
; AVX512BWVL-NEXT: retq
1177-
; AVX512BWVL-NEXT: .LBB27_2:
1266+
; AVX512BWVL-NEXT: .LBB29_2:
11781267
; AVX512BWVL-NEXT: movl $1, %eax
11791268
; AVX512BWVL-NEXT: retq
11801269
%1 = icmp ne <3 x i32> %a, %b

0 commit comments

Comments
 (0)