[X86] Combine PTEST to TESTP #157249
Conversation
Combine `PTEST` to `TESTP` when only the sign bit is tested. Discovered in llvm#156233.
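For context, here is a minimal scalar model of the flag semantics the fold relies on (my own sketch, not part of the patch): `ptest` derives ZF from every bit of the AND of its operands, while `testps` derives ZF from the per-lane sign bits only, so once only sign bits are tested the explicit AND with the sign-bit splat is redundant.

#include <cstdint>

// Reference model (illustration only, not LLVM code): how PTEST and TESTPS
// each compute ZF over four 32-bit lanes.
static bool ptestZF(const uint32_t A[4], const uint32_t B[4]) {
  for (int I = 0; I < 4; ++I)
    if (A[I] & B[I])
      return false; // any surviving bit of A&B clears ZF
  return true;
}

static bool testpsZF(const uint32_t A[4], const uint32_t B[4]) {
  for (int I = 0; I < 4; ++I)
    if ((A[I] & B[I]) >> 31)
      return false; // any surviving *sign* bit of A&B clears ZF
  return true;
}

// With M = splat(0x80000000), only sign bits can survive A & M, so
// ptestZF(A & M, A & M) == testpsZF(A, A) for every A: the AND can be dropped.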
@llvm/pr-subscribers-backend-x86

Author: Abhishek Kaushik (abhishek-kaushik22)

Changes: Combine `PTEST` to `TESTP` when only the sign bit is tested. Discovered in #156233.

Full diff: https://github.com/llvm/llvm-project/pull/157249.diff

3 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ab21cf534b304..c95fd00828bcf 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -48624,6 +48624,45 @@ static SDValue combineCarryThroughADD(SDValue EFLAGS, SelectionDAG &DAG) {
return SDValue();
}
+static SDValue canFoldToTESTP(SDValue Val, const SDLoc &DL, const EVT PTestVT,
+ SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (!Subtarget.hasAVX())
+ return SDValue();
+
+ EVT VT = Val.getValueType();
+ unsigned EltBits = VT.getScalarSizeInBits();
+
+ if (EltBits != 32 && EltBits != 64)
+ return SDValue();
+
+ SDValue Op0 = Val.getOperand(0);
+ SDValue Op1 = Val.getOperand(1);
+
+ MVT FloatSVT = MVT::getFloatingPointVT(EltBits);
+ MVT FloatVT = MVT::getVectorVT(FloatSVT, VT.getVectorNumElements());
+
+ // (ptest (and Op0, splat(minSignedVal)), (and Op0, splat(minSignedVal))) ->
+ // (testp Op0, Op0)
+ APInt Splat;
+ if (ISD::isConstantSplatVector(Op1.getNode(), Splat) &&
+ Splat.getBitWidth() == EltBits && Splat.isMinSignedValue()) {
+ SDValue FpOp0 = DAG.getBitcast(FloatVT, Op0);
+ return DAG.getNode(X86ISD::TESTP, DL, PTestVT, FpOp0, FpOp0);
+ }
+
+  // (ptest (and (and X, splat(minSignedVal)), Op1), ...) -> (testp X, Op1)
+ if (Op0.getOpcode() == ISD::AND &&
+ ISD::isConstantSplatVector(Op0.getOperand(1).getNode(), Splat) &&
+ Splat.getBitWidth() == EltBits && Splat.isMinSignedValue()) {
+ SDValue FpOp0 = DAG.getBitcast(FloatVT, Op0.getOperand(0));
+ SDValue FpOp1 = DAG.getBitcast(FloatVT, Op1);
+ return DAG.getNode(X86ISD::TESTP, DL, PTestVT, FpOp0, FpOp1);
+ }
+
+ return SDValue();
+}
+
/// If we are inverting an PTEST/TESTP operand, attempt to adjust the CC
/// to avoid the inversion.
static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
@@ -48718,6 +48757,10 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
SDValue BC = peekThroughBitcasts(Op0);
EVT BCVT = BC.getValueType();
+ if (EFLAGS.getOpcode() == X86ISD::PTEST && BC.getOpcode() == ISD::AND)
+ if (SDValue V = canFoldToTESTP(BC, SDLoc(EFLAGS), VT, DAG, Subtarget))
+ return V;
+
// TESTZ(AND(X,Y),AND(X,Y)) == TESTZ(X,Y)
if (BC.getOpcode() == ISD::AND || BC.getOpcode() == X86ISD::FAND) {
return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
diff --git a/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll b/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
index 227e000c6be7f..2c7399a1a1fad 100644
--- a/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
+++ b/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
@@ -875,28 +875,12 @@ define i1 @mask_v8i32(<8 x i32> %a0) {
; SSE41-NEXT: sete %al
; SSE41-NEXT: retq
;
-; AVX1-LABEL: mask_v8i32:
-; AVX1: # %bb.0:
-; AVX1-NEXT: vptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0
-; AVX1-NEXT: sete %al
-; AVX1-NEXT: vzeroupper
-; AVX1-NEXT: retq
-;
-; AVX2-LABEL: mask_v8i32:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
-; AVX2-NEXT: vptest %ymm1, %ymm0
-; AVX2-NEXT: sete %al
-; AVX2-NEXT: vzeroupper
-; AVX2-NEXT: retq
-;
-; AVX512-LABEL: mask_v8i32:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
-; AVX512-NEXT: vptest %ymm1, %ymm0
-; AVX512-NEXT: sete %al
-; AVX512-NEXT: vzeroupper
-; AVX512-NEXT: retq
+; AVX-LABEL: mask_v8i32:
+; AVX: # %bb.0:
+; AVX-NEXT: vtestps %ymm0, %ymm0
+; AVX-NEXT: sete %al
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
%1 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
%2 = and i32 %1, 2147483648
%3 = icmp eq i32 %2, 0
@@ -965,28 +949,12 @@ define i1 @signtest_v8i32(<8 x i32> %a0) {
; SSE41-NEXT: sete %al
; SSE41-NEXT: retq
;
-; AVX1-LABEL: signtest_v8i32:
-; AVX1: # %bb.0:
-; AVX1-NEXT: vptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0
-; AVX1-NEXT: sete %al
-; AVX1-NEXT: vzeroupper
-; AVX1-NEXT: retq
-;
-; AVX2-LABEL: signtest_v8i32:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
-; AVX2-NEXT: vptest %ymm1, %ymm0
-; AVX2-NEXT: sete %al
-; AVX2-NEXT: vzeroupper
-; AVX2-NEXT: retq
-;
-; AVX512-LABEL: signtest_v8i32:
-; AVX512: # %bb.0:
-; AVX512-NEXT: vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
-; AVX512-NEXT: vptest %ymm1, %ymm0
-; AVX512-NEXT: sete %al
-; AVX512-NEXT: vzeroupper
-; AVX512-NEXT: retq
+; AVX-LABEL: signtest_v8i32:
+; AVX: # %bb.0:
+; AVX-NEXT: vtestps %ymm0, %ymm0
+; AVX-NEXT: sete %al
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
%1 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
%2 = icmp sgt i32 %1, -1
ret i1 %2
diff --git a/llvm/test/CodeGen/combine-ptest-to-testp.ll b/llvm/test/CodeGen/combine-ptest-to-testp.ll
new file mode 100644
index 0000000000000..7c8595f2dd756
--- /dev/null
+++ b/llvm/test/CodeGen/combine-ptest-to-testp.ll
@@ -0,0 +1,281 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2 | FileCheck %s
+
+define void @combine_ptest_to_vtestps_1(<4 x i32> noundef %a) {
+; CHECK-LABEL: combine_ptest_to_vtestps_1:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestps %xmm0, %xmm0
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <4 x i32> %a, splat (i32 -2147483648)
+ %rdx.or = tail call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %and)
+ %cmp.not = icmp eq i32 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestps_2(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestps_2:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestps %xmm1, %xmm0
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <4 x i32> %a, splat (i32 -2147483648)
+ %and1 = and <4 x i32> %and, %b
+ %rdx.or = tail call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %and1)
+ %cmp.not = icmp eq i32 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestps_3(<4 x i32> noundef %a, <4 x i32> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestps_3:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestps %xmm1, %xmm0
+; CHECK-NEXT: jae foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %not = and <4 x i32> %a, splat (i32 -2147483648)
+ %and = xor <4 x i32> %not, splat (i32 -2147483648)
+ %and1 = and <4 x i32> %and, %b
+ %rdx.or = tail call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %and1)
+ %cmp.not = icmp eq i32 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestps_4(<8 x i32> noundef %a) {
+; CHECK-LABEL: combine_ptest_to_vtestps_4:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestps %ymm0, %ymm0
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <8 x i32> %a, splat (i32 -2147483648)
+ %rdx.or = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %and)
+ %cmp.not = icmp eq i32 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestps_5(<8 x i32> noundef %a, <8 x i32> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestps_5:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestps %ymm1, %ymm0
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <8 x i32> %a, splat (i32 -2147483648)
+ %and1 = and <8 x i32> %and, %b
+ %rdx.or = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %and1)
+ %cmp.not = icmp eq i32 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestps_6(<8 x i32> noundef %a, <8 x i32> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestps_6:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestps %ymm1, %ymm0
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: jae foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %not = and <8 x i32> %a, splat (i32 -2147483648)
+ %and = xor <8 x i32> %not, splat (i32 -2147483648)
+ %and1 = and <8 x i32> %and, %b
+ %rdx.or = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %and1)
+ %cmp.not = icmp eq i32 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestpd_1(<2 x i64> noundef %a) {
+; CHECK-LABEL: combine_ptest_to_vtestpd_1:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestpd %xmm0, %xmm0
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <2 x i64> %a, splat (i64 -9223372036854775808)
+ %rdx.or = tail call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> %and)
+ %cmp.not = icmp eq i64 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestpd_2(<2 x i64> noundef %a, <2 x i64> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestpd_2:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestpd %xmm1, %xmm0
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <2 x i64> %a, splat (i64 -9223372036854775808)
+ %and1 = and <2 x i64> %and, %b
+ %rdx.or = tail call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> %and1)
+ %cmp.not = icmp eq i64 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestpd_3(<2 x i64> noundef %a, <2 x i64> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestpd_3:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vpandn {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT: vptest %xmm1, %xmm0
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %not = and <2 x i64> %a, splat (i64 -9223372036854775808)
+ %and = xor <2 x i64> %not, splat (i64 -9223372036854775808)
+ %and1 = and <2 x i64> %and, %b
+ %rdx.or = tail call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> %and1)
+ %cmp.not = icmp eq i64 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestpd_4(<4 x i64> noundef %a) {
+; CHECK-LABEL: combine_ptest_to_vtestpd_4:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestpd %ymm0, %ymm0
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <4 x i64> %a, splat (i64 -9223372036854775808)
+ %rdx.or = tail call i64 @llvm.vector.reduce.or.v4i64(<4 x i64> %and)
+ %cmp.not = icmp eq i64 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestpd_5(<4 x i64> noundef %a, <4 x i64> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestpd_5:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vtestpd %ymm1, %ymm0
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %and = and <4 x i64> %a, splat (i64 -9223372036854775808)
+ %and1 = and <4 x i64> %and, %b
+ %rdx.or = tail call i64 @llvm.vector.reduce.or.v4i64(<4 x i64> %and1)
+ %cmp.not = icmp eq i64 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+define void @combine_ptest_to_vtestpd_6(<4 x i64> noundef %a, <4 x i64> noundef %b) {
+; CHECK-LABEL: combine_ptest_to_vtestpd_6:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vpbroadcastq {{.*#+}} ymm2 = [9223372036854775808,9223372036854775808,9223372036854775808,9223372036854775808]
+; CHECK-NEXT: vpandn %ymm2, %ymm0, %ymm0
+; CHECK-NEXT: vptest %ymm1, %ymm0
+; CHECK-NEXT: vzeroupper
+; CHECK-NEXT: jne foo@PLT # TAILCALL
+; CHECK-NEXT: # %bb.1: # %if.end
+; CHECK-NEXT: retq
+entry:
+ %not = and <4 x i64> %a, splat (i64 -9223372036854775808)
+ %and = xor <4 x i64> %not, splat (i64 -9223372036854775808)
+ %and1 = and <4 x i64> %and, %b
+ %rdx.or = tail call i64 @llvm.vector.reduce.or.v4i64(<4 x i64> %and1)
+ %cmp.not = icmp eq i64 %rdx.or, 0
+ br i1 %cmp.not, label %if.end, label %if.then
+
+if.then:
+ tail call void @foo()
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+declare void @foo()
+declare i32 @llvm.vector.reduce.or.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.or.v8i32(<8 x i32>)
+declare i64 @llvm.vector.reduce.or.v2i64(<2 x i64>)
+declare i64 @llvm.vector.reduce.or.v4i64(<4 x i64>)
I was not able to get this IR to work with this combine:

define void @combine_ptest_to_vtestpd_6(<4 x i64> noundef %a, <4 x i64> noundef %b) {
; CHECK-LABEL: combine_ptest_to_vtestpd_6:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vpbroadcastq {{.*#+}} ymm2 = [9223372036854775808,9223372036854775808,9223372036854775808,9223372036854775808]
; CHECK-NEXT: vpandn %ymm2, %ymm0, %ymm0
; CHECK-NEXT: vptest %ymm1, %ymm0
; CHECK-NEXT: vzeroupper
; CHECK-NEXT: jne foo@PLT # TAILCALL
; CHECK-NEXT: # %bb.1: # %if.end
; CHECK-NEXT: retq
entry:
%not = and <4 x i64> %a, splat (i64 -9223372036854775808)
%and = xor <4 x i64> %not, splat (i64 -9223372036854775808)
%and1 = and <4 x i64> %and, %b
%rdx.or = tail call i64 @llvm.vector.reduce.or.v4i64(<4 x i64> %and1)
%cmp.not = icmp eq i64 %rdx.or, 0
br i1 %cmp.not, label %if.end, label %if.then
if.then:
tail call void @foo()
br label %if.end
if.end:
ret void
}

The DAG for this looks like

@RKSimon, @phoebewang, @e-kud Can you please suggest a simpler way of doing this that I'm missing? We know that the known bits will have only the sign bit unknown; maybe we can use that?
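One direction that might avoid pattern-matching the XOR (a hypothetical sketch, not validated against this patch; the helper name foldSignAndnToTESTP is my own): by this point the XOR against the sign splat has become X86ISD::ANDNP, and TESTP reports the complemented-operand test through CF, so the fold could move the condition from ZF to CF instead.

// Hypothetical sketch: PTEST(AND(ANDNP(X, splat(SignBit)), Y), same) tests
// exactly the sign bits of ~X & Y, which is what CF of TESTP(X, Y) reports,
// so E/NE on ZF becomes B/AE on CF.
static SDValue foldSignAndnToTESTP(SDValue And, const SDLoc &DL, EVT PTestVT,
                                   X86::CondCode &CC, SelectionDAG &DAG) {
  SDValue AndN = And.getOperand(0); // assumes ANDNP is operand 0 of the AND
  if (AndN.getOpcode() != X86ISD::ANDNP)
    return SDValue();
  EVT VT = AndN.getValueType();
  unsigned EltBits = VT.getScalarSizeInBits();
  APInt Splat;
  if (!ISD::isConstantSplatVector(AndN.getOperand(1).getNode(), Splat) ||
      Splat.getBitWidth() != EltBits || !Splat.isMinSignedValue())
    return SDValue();
  // ZF==1 on the original PTEST corresponds to CF==1 on the TESTP.
  CC = (CC == X86::COND_E) ? X86::COND_B : X86::COND_AE;
  MVT FloatVT = MVT::getVectorVT(MVT::getFloatingPointVT(EltBits),
                                 VT.getVectorNumElements());
  return DAG.getNode(X86ISD::TESTP, DL, PTestVT,
                     DAG.getBitcast(FloatVT, AndN.getOperand(0)),
                     DAG.getBitcast(FloatVT, And.getOperand(1)));
}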
RKSimon left a comment:
We already have a similar fold in combinePTESTCC that uses SimplifyMultipleUseDemandedBits - any idea why it isn't firing?
I guess because of this; we should probably move this combine to the end:

// TESTZ(AND(X,Y),AND(X,Y)) == TESTZ(X,Y)
if (BC.getOpcode() == ISD::AND || BC.getOpcode() == X86ISD::FAND) {
return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
DAG.getBitcast(OpVT, BC.getOperand(0)),
DAG.getBitcast(OpVT, BC.getOperand(1)));
}

There is one other which is very similar:

if (DAG.ComputeNumSignBits(BC) == EltBits) {
assert(VT == MVT::i32 && "Expected i32 EFLAGS comparison result");
APInt SignMask = APInt::getSignMask(EltBits);
if (SDValue Res =
TLI.SimplifyMultipleUseDemandedBits(BC, SignMask, DAG)) {
// For vXi16 cases we need to use pmovmksb and extract every other
// sign bit.
SDLoc DL(EFLAGS);
if ((EltBits == 32 || EltBits == 64) && Subtarget.hasAVX()) {
MVT FloatSVT = MVT::getFloatingPointVT(EltBits);
MVT FloatVT =
MVT::getVectorVT(FloatSVT, OpVT.getSizeInBits() / EltBits);
Res = DAG.getBitcast(FloatVT, Res);
return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Res, Res);

It checks for all sign bits, but it could handle the simple sign-bit-only case here as well.
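If that all-sign-bits guard is what blocks it, one relaxation might be a known-bits check (my assumption, untested): after the AND with the sign-bit splat, ComputeNumSignBits is only 1, but known bits already prove every non-sign bit is zero.

// Hypothetical sketch: admit values whose non-sign bits are all known zero,
// not only values that are pure sign extensions.
KnownBits Known = DAG.computeKnownBits(BC);
APInt NonSignBits = ~APInt::getSignMask(EltBits);
if (NonSignBits.isSubsetOf(Known.Zero)) {
  // Only the sign bit can possibly be set, so the PTEST is already a pure
  // sign-bit test and the TESTP path above could be taken directly.
}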
@abhishek-kaushik22 reverse-ping |
Closing - this was handled in #165601 |