@houngkoungting
Contributor

avgceil version: https://alive2.llvm.org/ce/z/2CKrRh
Fixes #147773. After several iterations, I believe this version is correct and complete.

@RKSimon

@llvmbot llvmbot added backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well labels Aug 6, 2025
@llvmbot
Member

llvmbot commented Aug 6, 2025

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: 黃國庭 (houngkoungting)


Full diff: https://github.com/llvm/llvm-project/pull/152273.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+45)
  • (added) llvm/test/CodeGen/AArch64/trunc-avg-fold.ll (+43)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d70e96938ed9a..9ff256f8090ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16294,6 +16294,51 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   // because targets may prefer a wider type during later combines and invert
   // this transform.
   switch (N0.getOpcode()) {
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORU:
+    if (!LegalOperations && N0.hasOneUse() &&
+        TLI.isOperationLegal(N0.getOpcode(), VT)) {
+      SDValue X = N0.getOperand(0);
+      SDValue Y = N0.getOperand(1);
+
+      KnownBits KnownX = DAG.computeKnownBits(X);
+      KnownBits KnownY = DAG.computeKnownBits(Y);
+
+      unsigned SrcBits = X.getScalarValueSizeInBits();
+      unsigned DstBits = VT.getScalarSizeInBits();
+      unsigned NeededLeadingZeros = SrcBits - DstBits + 1;
+
+      if (KnownX.countMinLeadingZeros() >= NeededLeadingZeros &&
+          KnownY.countMinLeadingZeros() >= NeededLeadingZeros) {
+        SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
+        SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
+        return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
+      }
+    }
+    break;
+
+  case ISD::AVGCEILS:
+  case ISD::AVGFLOORS:
+    if (!LegalOperations && N0.hasOneUse() &&
+        TLI.isOperationLegal(N0.getOpcode(), VT)) {
+      SDValue X = N0.getOperand(0);
+      SDValue Y = N0.getOperand(1);
+
+      unsigned SignBitsX = DAG.ComputeNumSignBits(X);
+      unsigned SignBitsY = DAG.ComputeNumSignBits(Y);
+
+      unsigned SrcBits = X.getScalarValueSizeInBits();
+      unsigned DstBits = VT.getScalarSizeInBits();
+      unsigned NeededSignBits = SrcBits - DstBits + 1;
+
+      if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) {
+        SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
+        SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
+        return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
+      }
+    }
+    break;
+
   case ISD::ADD:
   case ISD::SUB:
   case ISD::MUL:
diff --git a/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll b/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll
new file mode 100644
index 0000000000000..175f54d6f9c05
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll
@@ -0,0 +1,43 @@
+; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s
+
+; CHECK-LABEL: test_avgceil_u
+; CHECK: uhadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgceil_u(<8 x i16> %a, <8 x i16> %b) {
+  %ta = trunc <8 x i16> %a to <8 x i8>
+  %tb = trunc <8 x i16> %b to <8 x i8>
+  %res = call <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+  ret <8 x i8> %res
+}
+
+; CHECK-LABEL: test_avgceil_s
+; CHECK: shadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgceil_s(<8 x i16> %a, <8 x i16> %b) {
+  %ta = trunc <8 x i16> %a to <8 x i8>
+  %tb = trunc <8 x i16> %b to <8 x i8>
+  %res = call <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+  ret <8 x i8> %res
+}
+
+; CHECK-LABEL: test_avgfloor_u
+; CHECK: urhadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgfloor_u(<8 x i16> %a, <8 x i16> %b) {
+  %ta = trunc <8 x i16> %a to <8 x i8>
+  %tb = trunc <8 x i16> %b to <8 x i8>
+  %res = call <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+  ret <8 x i8> %res
+}
+
+; CHECK-LABEL: test_avgfloor_s
+; CHECK: srhadd v0.8b, v0.8b, v1.8b
+define <8 x i8> @test_avgfloor_s(<8 x i16> %a, <8 x i16> %b) {
+  %ta = trunc <8 x i16> %a to <8 x i8>
+  %tb = trunc <8 x i16> %b to <8 x i8>
+  %res = call <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
+  ret <8 x i8> %res
+}
+
+declare <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8>, <8 x i8>)
+declare <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8>, <8 x i8>)
+declare <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8>, <8 x i8>)
+declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>)
+

@houngkoungting houngkoungting changed the title from "[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have s…" to "[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits" Aug 6, 2025

unsigned SrcBits = X.getScalarValueSizeInBits();
unsigned DstBits = VT.getScalarSizeInBits();
unsigned NeededLeadingZeros = SrcBits - DstBits + 1;
Collaborator

Shouldn't this be NeededLeadingZeros = SrcBits - DstBits? (NeededSignBits is correct, though you could use ComputeMaxSignificantBits instead if you wish.)

Collaborator

Sorry, I think you misunderstood: you need to use computeKnownBits.countMinLeadingZeros() >= (SrcBits - DstBits).
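
To illustrate why SrcBits - DstBits leading zeros are enough for the unsigned nodes, here is a minimal standalone sketch (not code from the patch) that brute-forces the i16 to i8 case. It assumes AVGFLOORU and AVGCEILU behave as (x + y) >> 1 and (x + y + 1) >> 1 without intermediate overflow; every helper name below is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Reference semantics of the unsigned averaging nodes, computed in 32 bits so
// the intermediate sum cannot overflow.
static uint32_t avgFloorU(uint32_t X, uint32_t Y) { return (X + Y) >> 1; }
static uint32_t avgCeilU(uint32_t X, uint32_t Y) { return (X + Y + 1) >> 1; }

// Model of truncating an i16 value to i8: keep only the low 8 bits.
static uint32_t truncToI8(uint32_t V) { return V & 0xFFu; }

// Returns true if trunc(avg(x, y)) == avg(trunc(x), trunc(y)) for every pair
// of i16 inputs in [0, Limit), for both the floor and the ceil variant.
static bool foldHoldsBelow(uint32_t Limit) {
  for (uint32_t X = 0; X < Limit; ++X) {
    for (uint32_t Y = 0; Y < Limit; ++Y) {
      if (truncToI8(avgFloorU(X, Y)) != avgFloorU(truncToI8(X), truncToI8(Y)))
        return false;
      if (truncToI8(avgCeilU(X, Y)) != avgCeilU(truncToI8(X), truncToI8(Y)))
        return false;
    }
  }
  return true;
}

// Hypothetical self-check summarizing the two cases of interest.
static void checkUnsignedFold() {
  assert(foldHoldsBelow(256));   // upper 8 bits zero: the fold is safe
  assert(!foldHoldsBelow(257));  // 256 has a bit above bit 7: the fold breaks
}
```

The first check covers inputs such as 255, which have exactly 8 leading zeros in i16, so requiring SrcBits - DstBits + 1 would reject cases where the fold is still valid. The second check shows that the requirement cannot be relaxed further.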

@github-actions

github-actions bot commented Aug 7, 2025

✅ With the latest revision this PR passed the undef deprecator.

@houngkoungting
Contributor Author

I will fix it tomorrow.


%mask = insertelement <8 x i16> poison, i16 255, i32 0
%mask.splat = shufflevector <8 x i16> %mask, <8 x i16> poison, <8 x i32> zeroinitializer
%ta16 = and <8 x i16> %a, %mask.splat
%tb16 = and <8 x i16> %b, %mask.splat
Collaborator

Why not use splat (i16 255)? It was added to avoid the messy shufflevector(insertelement) pattern.

@houngkoungting
Contributor Author

Hi @RKSimon, did I get this right?

return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
}
}
break;
Collaborator

You should be able to reuse the ISD::ABD code later in the switch statement now; it has near-identical logic.

@houngkoungting
Contributor Author

Hi @RKSimon, I updated the test cases first; I'll modify the DAG code tomorrow.

Comment on lines 16307 to 16308
if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) &&
KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) {
Contributor

Suggested change
-if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) &&
-    KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) {
+if (KnownX.countMaxActiveBits() <= DstBits &&
+    KnownY.countMaxActiveBits() <= DstBits) {

Then you don't need SrcBits.

Collaborator

You could even do this:

APInt UpperBits = APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
if (DAG.MaskedValueIsZero(X, UpperBits) &&
    DAG.MaskedValueIsZero(Y, UpperBits)) {

Contributor

Or APInt::getBitsSetFrom(SrcBits, DstBits). (Sometimes I think we have too many different helper functions!)
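
The three spellings suggested in this thread (leading zeros, active bits, and a mask over the discarded bits) express the same condition. The sketch below is a standalone model, not LLVM code: it works on fully known 64-bit values rather than KnownBits, where the real queries only give conservative bounds, and all function names are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstdint>

// Number of zero bits at the top of a SrcBits-wide value V.
static unsigned leadingZeros(uint64_t V, unsigned SrcBits) {
  unsigned N = 0;
  while (N < SrcBits && ((V >> (SrcBits - 1 - N)) & 1) == 0)
    ++N;
  return N;
}

// Number of bits needed to represent V: the width minus the leading zeros.
static unsigned activeBits(uint64_t V, unsigned SrcBits) {
  return SrcBits - leadingZeros(V, SrcBits);
}

// Mask with bits [DstBits, SrcBits) set, i.e. the bits a truncate discards.
// Assumes DstBits < SrcBits <= 64.
static uint64_t upperMask(unsigned SrcBits, unsigned DstBits) {
  uint64_t Src = SrcBits == 64 ? ~0ULL : ((1ULL << SrcBits) - 1);
  uint64_t Dst = (1ULL << DstBits) - 1;
  return Src & ~Dst;
}

// Checks that all three formulations give the same answer for V.
static bool allFormsAgree(uint64_t V, unsigned SrcBits, unsigned DstBits) {
  bool ByLeadingZeros = leadingZeros(V, SrcBits) >= SrcBits - DstBits;
  bool ByActiveBits = activeBits(V, SrcBits) <= DstBits;
  bool ByMask = (V & upperMask(SrcBits, DstBits)) == 0;
  return ByLeadingZeros == ByActiveBits && ByActiveBits == ByMask;
}

// Exhaustively compares the formulations over every SrcBits-wide value.
// Intended for small widths such as 16.
static bool allValuesAgree(unsigned SrcBits, unsigned DstBits) {
  for (uint64_t V = 0; V < (1ULL << SrcBits); ++V)
    if (!allFormsAgree(V, SrcBits, DstBits))
      return false;
  return true;
}
```

For the i16 to i8 case, upperMask(16, 8) is 0xFF00, the mask that both getHighBitsSet(SrcBits, SrcBits - DstBits) and getBitsSetFrom(SrcBits, DstBits) describe.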

Comment on lines 16305 to 16306
KnownBits KnownX = DAG.computeKnownBits(X);
KnownBits KnownY = DAG.computeKnownBits(Y);
Contributor

computeKnownBits can be expensive. You should rearrange this code so that you only call computeKnownBits(Y) if the test on the result of computeKnownBits(X) succeeds.
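
A minimal sketch of the rearrangement described here, using a hypothetical stand-in for the expensive query (the Analysis type and its method are illustrative only and are not LLVM APIs). Relying on the short-circuit behaviour of && means the second query runs only when the first one succeeds.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for an expensive value-tracking query. It counts how
// many times it has been invoked.
struct Analysis {
  unsigned Calls = 0;
  bool upperBitsKnownZero(uint16_t V) {
    ++Calls;
    return (V & 0xFF00u) == 0;
  }
};

// Eager form: both queries always run, even when X already fails.
static bool canFoldEager(Analysis &A, uint16_t X, uint16_t Y) {
  bool XOk = A.upperBitsKnownZero(X);
  bool YOk = A.upperBitsKnownZero(Y);
  return XOk && YOk;
}

// Lazy form: && short-circuits, so Y is only queried if X passed.
static bool canFoldLazy(Analysis &A, uint16_t X, uint16_t Y) {
  return A.upperBitsKnownZero(X) && A.upperBitsKnownZero(Y);
}
```

Both forms return the same answer; they differ only in how many queries are issued when X fails the test.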

Collaborator

@RKSimon RKSimon left a comment

A couple of suggestions to reduce value-tracking costs.

unsigned DstBits = VT.getScalarSizeInBits();
unsigned NeededSignBits = SrcBits - DstBits + 1;

if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) {
Collaborator

 if (DAG.ComputeNumSignBits(X) >= NeededSignBits &&
     DAG.ComputeNumSignBits(Y) >= NeededSignBits) {
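
For the signed nodes the threshold keeps its + 1: an i16 value with at least SrcBits - DstBits + 1 = 9 sign bits lies in [-128, 127], which is exactly the i8 range. The standalone sketch below (hypothetical helper names, not code from the patch) brute-forces that case, assuming AVGFLOORS and AVGCEILS compute floor((x + y) / 2) and floor((x + y + 1) / 2) without intermediate overflow.

```cpp
#include <cassert>

// Floor of S / 2, written without right-shifting a negative value.
static int floorHalf(int S) { return (S - (S < 0 ? 1 : 0)) / 2; }

// Reference semantics of the signed averaging nodes in a wide type.
static int avgFloorS(int X, int Y) { return floorHalf(X + Y); }
static int avgCeilS(int X, int Y) { return floorHalf(X + Y + 1); }

// Model of truncating to i8: keep the low 8 bits and reinterpret them as a
// signed value (assumes a two's complement representation of int).
static int truncToI8(int V) {
  int Low = V & 0xFF;
  return Low >= 128 ? Low - 256 : Low;
}

// Returns true if trunc(avg(x, y)) == avg(trunc(x), trunc(y)) for every pair
// of inputs in [Lo, Hi], for both the floor and the ceil variant.
static bool foldHoldsInRange(int Lo, int Hi) {
  for (int X = Lo; X <= Hi; ++X) {
    for (int Y = Lo; Y <= Hi; ++Y) {
      if (truncToI8(avgFloorS(X, Y)) != avgFloorS(truncToI8(X), truncToI8(Y)))
        return false;
      if (truncToI8(avgCeilS(X, Y)) != avgCeilS(truncToI8(X), truncToI8(Y)))
        return false;
    }
  }
  return true;
}
```

With inputs restricted to [-128, 127] the fold holds for every pair. Admitting 128, which has only 8 sign bits in i16, is enough to break it, because truncation turns 128 into -128.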

Collaborator

@RKSimon RKSimon left a comment

LGTM - cheers

@houngkoungting
Contributor Author

houngkoungting commented Aug 18, 2025

@RKSimon @jayfoad Thank you both for your review.

@RKSimon RKSimon merged commit 0773854 into llvm:main Aug 18, 2025
8 of 9 checks passed