[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits #152273
…ufficient leading zero/sign bits-1
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64
Author: 黃國庭 (houngkoungting)
Changes
avgceil version: https://alive2.llvm.org/ce/z/2CKrRh
@RKSimon
Full diff: https://github.com/llvm/llvm-project/pull/152273.diff
2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d70e96938ed9a..9ff256f8090ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16294,6 +16294,51 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// because targets may prefer a wider type during later combines and invert
// this transform.
switch (N0.getOpcode()) {
+ case ISD::AVGCEILU:
+ case ISD::AVGFLOORU:
+ if (!LegalOperations && N0.hasOneUse() &&
+ TLI.isOperationLegal(N0.getOpcode(), VT)) {
+ SDValue X = N0.getOperand(0);
+ SDValue Y = N0.getOperand(1);
+
+ KnownBits KnownX = DAG.computeKnownBits(X);
+ KnownBits KnownY = DAG.computeKnownBits(Y);
+
+ unsigned SrcBits = X.getScalarValueSizeInBits();
+ unsigned DstBits = VT.getScalarSizeInBits();
+ unsigned NeededLeadingZeros = SrcBits - DstBits + 1;
+
+ if (KnownX.countMinLeadingZeros() >= NeededLeadingZeros &&
+ KnownY.countMinLeadingZeros() >= NeededLeadingZeros) {
+ SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
+ SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
+ return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
+ }
+ }
+ break;
+
+ case ISD::AVGCEILS:
+ case ISD::AVGFLOORS:
+ if (!LegalOperations && N0.hasOneUse() &&
+ TLI.isOperationLegal(N0.getOpcode(), VT)) {
+ SDValue X = N0.getOperand(0);
+ SDValue Y = N0.getOperand(1);
+
+ unsigned SignBitsX = DAG.ComputeNumSignBits(X);
+ unsigned SignBitsY = DAG.ComputeNumSignBits(Y);
+
+ unsigned SrcBits = X.getScalarValueSizeInBits();
+ unsigned DstBits = VT.getScalarSizeInBits();
+ unsigned NeededSignBits = SrcBits - DstBits + 1;
+
+ if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) {
+ SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
+ SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
+ return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
+ }
+ }
+ break;
+
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
diff --git a/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll b/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll
new file mode 100644
index 0000000000000..175f54d6f9c05
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll
@@ -0,0 +1,43 @@
+; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s
+
+; CHECK-LABEL: test_avgceil_u
+; CHECK: uhadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgceil_u(<8 x i16> %a, <8 x i16> %b) {
+ %ta = trunc <8 x i16> %a to <8 x i8>
+ %tb = trunc <8 x i16> %b to <8 x i8>
+ %res = call <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+ ret <8 x i8> %res
+}
+
+; CHECK-LABEL: test_avgceil_s
+; CHECK: shadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgceil_s(<8 x i16> %a, <8 x i16> %b) {
+ %ta = trunc <8 x i16> %a to <8 x i8>
+ %tb = trunc <8 x i16> %b to <8 x i8>
+ %res = call <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+ ret <8 x i8> %res
+}
+
+; CHECK-LABEL: test_avgfloor_u
+; CHECK: urhadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgfloor_u(<8 x i16> %a, <8 x i16> %b) {
+ %ta = trunc <8 x i16> %a to <8 x i8>
+ %tb = trunc <8 x i16> %b to <8 x i8>
+ %res = call <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+ ret <8 x i8> %res
+}
+
+; CHECK-LABEL: test_avgfloor_s
+; CHECK: srhadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgfloor_s(<8 x i16> %a, <8 x i16> %b) {
+ %ta = trunc <8 x i16> %a to <8 x i8>
+ %tb = trunc <8 x i16> %b to <8 x i8>
+ %res = call <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+ ret <8 x i8> %res
+}
+
+declare <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8>, <8 x i8>)
+declare <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8>, <8 x i8>)
+declare <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8>, <8 x i8>)
+declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>)
+
unsigned SrcBits = X.getScalarValueSizeInBits();
unsigned DstBits = VT.getScalarSizeInBits();
unsigned NeededLeadingZeros = SrcBits - DstBits + 1;
Should this be NeededLeadingZeros = SrcBits - DstBits; ? (NeededSignBits is correct, though you could use ComputeMaxSignificantBits instead if you wish.)
Sorry I think you misunderstood - you need to use computeKnownBits.countMinLeadingZeros() >= (SrcBits - DstBits)
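The difference between the two thresholds discussed above can be checked on plain integers. The following is a minimal sketch, not the DAGCombiner code: it models an i16 to i8 truncation with ordinary C++ arithmetic, and all function names are hypothetical. It assumes the usual reading that an unsigned value with at least SrcBits - DstBits leading zeros fits in DstBits bits, while a signed value needs at least SrcBits - DstBits + 1 sign bits to fit.

```cpp
#include <cstdint>

// Floor and ceiling of s/2 without relying on shifts of negative values.
int floorHalf(int s) { return s >= 0 ? s / 2 : -((-s + 1) / 2); }
int ceilHalf(int s) { return floorHalf(s + 1); }

// Interpret the low `width` bits of `bits` as a two's complement integer.
int asSigned(unsigned bits, unsigned width) {
  bits &= (1u << width) - 1;
  return (bits >> (width - 1)) ? int(bits) - (1 << width) : int(bits);
}

// Compare trunc(avg(x, y)) with avg(trunc(x), trunc(y)) for the unsigned
// floor and ceil averages, using i16 -> i8 bit patterns.
bool unsignedFoldMatches(uint16_t x, uint16_t y) {
  unsigned wideF = (unsigned(x) + y) / 2, wideC = (unsigned(x) + y + 1) / 2;
  unsigned nx = x & 0xFF, ny = y & 0xFF;
  unsigned narrowF = (nx + ny) / 2, narrowC = (nx + ny + 1) / 2;
  return (wideF & 0xFF) == narrowF && (wideC & 0xFF) == narrowC;
}

// Same comparison for the signed floor and ceil averages.
bool signedFoldMatches(uint16_t x, uint16_t y) {
  int wx = asSigned(x, 16), wy = asSigned(y, 16);
  int nx = asSigned(x, 8), ny = asSigned(y, 8);
  unsigned wideF = unsigned(floorHalf(wx + wy)) & 0xFF;
  unsigned wideC = unsigned(ceilHalf(wx + wy)) & 0xFF;
  unsigned narrowF = unsigned(floorHalf(nx + ny)) & 0xFF;
  unsigned narrowC = unsigned(ceilHalf(nx + ny)) & 0xFF;
  return wideF == narrowF && wideC == narrowC;
}

// Every pair with at least SrcBits - DstBits = 8 leading zeros (0..255).
bool unsignedHoldsWithEightLeadingZeros() {
  for (unsigned x = 0; x < 256; ++x)
    for (unsigned y = 0; y < 256; ++y)
      if (!unsignedFoldMatches(static_cast<uint16_t>(x),
                               static_cast<uint16_t>(y)))
        return false;
  return true;
}

// Every pair with at least SrcBits - DstBits + 1 = 9 sign bits (-128..127).
bool signedHoldsWithNineSignBits() {
  for (int x = -128; x < 128; ++x)
    for (int y = -128; y < 128; ++y)
      if (!signedFoldMatches(static_cast<uint16_t>(x),
                             static_cast<uint16_t>(y)))
        return false;
  return true;
}
```

In this model the pair (0x00FF, 0x0001) has only 8 sign bits on the first operand and breaks the signed fold, while the same pair is fine for the unsigned fold because both values have 8 leading zeros.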
…ufficient leading zero/sign bits -2
✅ With the latest revision this PR passed the undef deprecator.
I will fix it tomorrow
…ufficient leading zero/sign bits -3
…ufficient leading zero/sign bits-4
…ufficient leading zero/sign bits-5
…ufficient leading zero/sign bits-6
%mask = insertelement <8 x i16> poison, i16 255, i32 0
%mask.splat = shufflevector <8 x i16> %mask, <8 x i16> poison, <8 x i32> zeroinitializer
%ta16 = and <8 x i16> %a, %mask.splat
%tb16 = and <8 x i16> %b, %mask.splat
Why not use splat (i16 255)? It was added to avoid the messy shufflevector(insertelement) pattern.
…ufficient leading zero/sign bits-7
Hi @RKSimon, did I get this right?
return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
}
}
break;
You should be able to reuse the ISD::ABD code later in the switch statement now - it has near-identical logic
…ufficient leading zero/sign bits-8
…ufficient leading zero/sign bits-9
Hi @RKSimon, I updated the test cases first; I’ll modify the DAG code tomorrow.
…ent leading zero/sign bits-10
if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) &&
    KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) {
- if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) &&
-     KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) {
+ if (KnownX.countMaxActiveBits() <= DstBits &&
+     KnownY.countMaxActiveBits() <= DstBits) {
Then you don't need SrcBits.
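Why the two conditions are interchangeable can be seen on concrete values. The sketch below is an illustration with hypothetical helper names, not the KnownBits API: it assumes the active-bit count of a value is its width minus its leading-zero count, which is how the relationship between countMaxActiveBits and countMinLeadingZeros is understood here.

```cpp
#include <cstdint>

// Leading zero bits of `v` viewed as a `width`-bit value (width <= 32).
unsigned leadingZeros(uint32_t v, unsigned width) {
  unsigned n = 0;
  while (n < width && ((v >> (width - 1 - n)) & 1u) == 0)
    ++n;
  return n;
}

// Bits needed to represent `v`: the width minus the leading zeros.
unsigned activeBits(uint32_t v, unsigned width) {
  return width - leadingZeros(v, width);
}

// Check that "leading zeros >= SrcBits - DstBits" and
// "active bits <= DstBits" agree for every SrcBits-wide value.
bool checksAgree(unsigned srcBits, unsigned dstBits) {
  for (uint32_t v = 0; v < (1u << srcBits); ++v) {
    bool viaLeadingZeros = leadingZeros(v, srcBits) >= srcBits - dstBits;
    bool viaActiveBits = activeBits(v, srcBits) <= dstBits;
    if (viaLeadingZeros != viaActiveBits)
      return false;
  }
  return true;
}
```

Calling checksAgree(16, 8) walks all 16-bit values; the second form only mentions DstBits, which is the simplification the comment points at.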
You could even do this:
APInt UpperBits = APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
if (DAG.MaskedValueIsZero(X, UpperBits) &&
DAG.MaskedValueIsZero(Y, UpperBits)) {
Or APInt::getBitsSetFrom(SrcBits, DstBits). (Sometimes I think we have too many different helper functions!)
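To illustrate why both helper spellings describe the same mask, here is a small sketch on plain 64-bit integers. The functions below are hypothetical stand-ins that mimic the described behaviour of the APInt helpers for widths below 64; they are not the LLVM implementations, and upperBitsAreZero tests a concrete value rather than known bits.

```cpp
#include <cstdint>

// Stand-in for getHighBitsSet(width, n): top n bits of a width-bit value.
// Assumes width < 64 and n <= width.
uint64_t highBitsSet(unsigned width, unsigned n) {
  uint64_t all = (1ULL << width) - 1;
  uint64_t low = (1ULL << (width - n)) - 1;
  return all & ~low;
}

// Stand-in for getBitsSetFrom(width, lo): bits [lo, width) are set.
// Assumes width < 64 and lo <= width.
uint64_t bitsSetFrom(unsigned width, unsigned lo) {
  uint64_t all = (1ULL << width) - 1;
  return all & ~((1ULL << lo) - 1);
}

// Concrete-value analogue of the masked-zero query: none of the masked
// bits of `value` is set.
bool upperBitsAreZero(uint64_t value, uint64_t mask) {
  return (value & mask) == 0;
}
```

With SrcBits = 16 and DstBits = 8, both highBitsSet(16, 16 - 8) and bitsSetFrom(16, 8) give 0xFF00 in this model.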
KnownBits KnownX = DAG.computeKnownBits(X);
KnownBits KnownY = DAG.computeKnownBits(Y);
computeKnownBits can be expensive. You should rearrange this code so that you only call computeKnownBits(Y) if the test on the result of computeKnownBits(X) succeeds.
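The cost argument can be sketched with a counter. The Analysis type below is a hypothetical stand-in for an expensive value-tracking query, not an LLVM class; it only records how many times it runs so the eager and lazy orderings can be compared.

```cpp
#include <cstdint>

// Hypothetical stand-in for an expensive analysis; counts how often it runs.
struct Analysis {
  unsigned calls = 0;
  unsigned leadingZeros16(uint16_t v) {
    ++calls;
    unsigned n = 0;
    while (n < 16 && ((v >> (15 - n)) & 1u) == 0)
      ++n;
    return n;
  }
};

// Eager: both operands are analysed before either result is tested.
bool canNarrowEager(Analysis &a, uint16_t x, uint16_t y, unsigned needed) {
  unsigned zx = a.leadingZeros16(x);
  unsigned zy = a.leadingZeros16(y);
  return zx >= needed && zy >= needed;
}

// Lazy: && short-circuits, so y is only analysed when x already qualifies.
bool canNarrowLazy(Analysis &a, uint16_t x, uint16_t y, unsigned needed) {
  return a.leadingZeros16(x) >= needed && a.leadingZeros16(y) >= needed;
}
```

When the first operand fails the test, the lazy form performs one query instead of two; when both operands qualify, the two forms do the same work.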
RKSimon
left a comment
A couple of suggestions to reduce value tracking costs.
unsigned DstBits = VT.getScalarSizeInBits();
unsigned NeededSignBits = SrcBits - DstBits + 1;
if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) {
if (DAG.ComputeNumSignBits(X) >= NeededSignBits &&
DAG.ComputeNumSignBits(Y) >= NeededSignBits) {
if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) &&
    KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) {
You could even do this:
APInt UpperBits = APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
if (DAG.MaskedValueIsZero(X, UpperBits) &&
DAG.MaskedValueIsZero(Y, UpperBits)) {
…ufficient leading zero/sign bits-11
RKSimon
left a comment
LGTM - cheers
avgceil version : https://alive2.llvm.org/ce/z/2CKrRh
Fixes #147773. After several iterations, I believe this version is correct and complete.
@RKSimon