From 3982e37e2f79bea87af313c698f3ea3d9ae9ee63 Mon Sep 17 00:00:00 2001
From: Lauren Chin
Date: Tue, 14 Oct 2025 06:20:55 -0400
Subject: [PATCH] [DAG] Fold mismatched widened avg idioms to narrow form
 (#147946)

Correct mismatched widened averaging idioms by folding them to the
narrow form:

`trunc(avgceilu(sext(x), sext(y))) -> avgceils(x, y)`
`trunc(avgceils(zext(x), zext(y))) -> avgceilu(x, y)`

When the inputs are sign-extended, the unsigned and signed wide averages
agree in every bit that survives the truncation, so the truncated result
equals the signed average of the narrow values; symmetrically,
zero-extended inputs let the signed wide average be replaced by the
unsigned narrow operation.

alive2: https://alive2.llvm.org/ce/z/ZRbfHT
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 28 ++++++-
 llvm/test/CodeGen/AArch64/arm64-vhadd.ll      | 82 +++++++++++++++++++
 2 files changed, 108 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d2ea6525e1116..1b8cd9b97451b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16481,10 +16481,34 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
                                  DAG, DL);
     }
     break;
-  case ISD::AVGFLOORS:
-  case ISD::AVGFLOORU:
   case ISD::AVGCEILS:
   case ISD::AVGCEILU:
+    // trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
+    // trunc (avgceils (zext (x), zext (y))) -> avgceilu(x, y)
+    if (N0.hasOneUse()) {
+      SDValue Op0 = N0.getOperand(0);
+      SDValue Op1 = N0.getOperand(1);
+      if (N0.getOpcode() == ISD::AVGCEILU) {
+        if (TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT) &&
+            Op0.getOpcode() == ISD::SIGN_EXTEND &&
+            Op1.getOpcode() == ISD::SIGN_EXTEND &&
+            Op0.getOperand(0).getValueType() == VT &&
+            Op1.getOperand(0).getValueType() == VT)
+          return DAG.getNode(ISD::AVGCEILS, DL, VT, Op0.getOperand(0),
+                             Op1.getOperand(0));
+      } else {
+        if (TLI.isOperationLegalOrCustom(ISD::AVGCEILU, VT) &&
+            Op0.getOpcode() == ISD::ZERO_EXTEND &&
+            Op1.getOpcode() == ISD::ZERO_EXTEND &&
+            Op0.getOperand(0).getValueType() == VT &&
+            Op1.getOperand(0).getValueType() == VT)
+          return DAG.getNode(ISD::AVGCEILU, DL, VT, Op0.getOperand(0),
+                             Op1.getOperand(0));
+      }
+    }
+    [[fallthrough]];
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
   case ISD::ABDS:
   case ISD::ABDU:
     // (trunc (avg a, b)) -> (avg (trunc a), (trunc b))
diff --git a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
index 076cbf7fce6cc..a505b42e3423a 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vhadd.ll
@@ -1408,6 +1408,88 @@ define <4 x i16> @ext_via_i19(<4 x i16> %a) {
   ret <4 x i16> %t6
 }
 
+define <8 x i8> @srhadd_v8i8_trunc(<8 x i8> %s0, <8 x i8> %s1) {
+; CHECK-LABEL: srhadd_v8i8_trunc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd.8b v0, v0, v1
+; CHECK-NEXT:    ret
+  %s0s = sext <8 x i8> %s0 to <8 x i16>
+  %s1s = sext <8 x i8> %s1 to <8 x i16>
+  %s = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %s0s, <8 x i16> %s1s)
+  %s2 = trunc <8 x i16> %s to <8 x i8>
+  ret <8 x i8> %s2
+}
+
+define <4 x i16> @srhadd_v4i16_trunc(<4 x i16> %s0, <4 x i16> %s1) {
+; CHECK-LABEL: srhadd_v4i16_trunc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd.4h v0, v0, v1
+; CHECK-NEXT:    ret
+  %s0s = sext <4 x i16> %s0 to <4 x i32>
+  %s1s = sext <4 x i16> %s1 to <4 x i32>
+  %s = call <4 x i32> @llvm.aarch64.neon.urhadd.v4i32(<4 x i32> %s0s, <4 x i32> %s1s)
+  %s2 = trunc <4 x i32> %s to <4 x i16>
+  ret <4 x i16> %s2
+}
+
+define <2 x i32> @srhadd_v2i32_trunc(<2 x i32> %s0, <2 x i32> %s1) {
+; CHECK-LABEL: srhadd_v2i32_trunc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll.2d v0, v0, #0
+; CHECK-NEXT:    sshll.2d v1, v1, #0
+; CHECK-NEXT:    eor.16b v2, v0, v1
+; CHECK-NEXT:    orr.16b v0, v0, v1
+; CHECK-NEXT:    ushr.2d v1, v2, #1
+; CHECK-NEXT:    sub.2d v0, v0, v1
+; CHECK-NEXT:    xtn.2s v0, v0
+; CHECK-NEXT:    ret
+  %s0s = sext <2 x i32> %s0 to <2 x i64>
+  %s1s = sext <2 x i32> %s1 to <2 x i64>
+  %s = call <2 x i64> @llvm.aarch64.neon.urhadd.v2i64(<2 x i64> %s0s, <2 x i64> %s1s)
+  %s2 = trunc <2 x i64> %s to <2 x i32>
+  ret <2 x i32> %s2
+}
+
+define <8 x i8> @urhadd_v8i8_trunc(<8 x i8> %s0, <8 x i8> %s1) {
+; CHECK-LABEL: urhadd_v8i8_trunc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd.8b v0, v0, v1
+; CHECK-NEXT:    ret
+  %s0s = zext <8 x i8> %s0 to <8 x i16>
+  %s1s = zext <8 x i8> %s1 to <8 x i16>
+  %s = call <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16> %s0s, <8 x i16> %s1s)
+  %s2 = trunc <8 x i16> %s to <8 x i8>
+  ret <8 x i8> %s2
+}
+
+define <4 x i16> @urhadd_v4i16_trunc(<4 x i16> %s0, <4 x i16> %s1) {
+; CHECK-LABEL: urhadd_v4i16_trunc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd.4h v0, v0, v1
+; CHECK-NEXT:    ret
+  %s0s = zext <4 x i16> %s0 to <4 x i32>
+  %s1s = zext <4 x i16> %s1 to <4 x i32>
+  %s = call <4 x i32> @llvm.aarch64.neon.srhadd.v4i32(<4 x i32> %s0s, <4 x i32> %s1s)
+  %s2 = trunc <4 x i32> %s to <4 x i16>
+  ret <4 x i16> %s2
+}
+
+define <2 x i32> @urhadd_v2i32_trunc(<2 x i32> %s0, <2 x i32> %s1) {
+; CHECK-LABEL: urhadd_v2i32_trunc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #1 // =0x1
+; CHECK-NEXT:    uaddl.2d v0, v0, v1
+; CHECK-NEXT:    dup.2d v1, x8
+; CHECK-NEXT:    add.2d v0, v0, v1
+; CHECK-NEXT:    shrn.2s v0, v0, #1
+; CHECK-NEXT:    ret
+  %s0s = zext <2 x i32> %s0 to <2 x i64>
+  %s1s = zext <2 x i32> %s1 to <2 x i64>
+  %s = call <2 x i64> @llvm.aarch64.neon.srhadd.v2i64(<2 x i64> %s0s, <2 x i64> %s1s)
+  %s2 = trunc <2 x i64> %s to <2 x i32>
+  ret <2 x i32> %s2
+}
+
 declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>)
 declare <4 x i16> @llvm.aarch64.neon.srhadd.v4i16(<4 x i16>, <4 x i16>)
 declare <2 x i32> @llvm.aarch64.neon.srhadd.v2i32(<2 x i32>, <2 x i32>)
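
For reference, the first of the two folds can be written as the kind of
src/tgt pair the linked alive2 proof checks. The sketch below is
illustrative and not part of the patch: i8/i16 stand in for the
narrow/wide types, the wide AVGCEILU is expanded arithmetically in one
extra bit of width (rounding averages have no generic IR spelling), and
tgt uses the overflow-free identity avgceils(x, y) = (x | y) - ((x ^ y) >>s 1);
@src/@tgt are alive2 naming conventions, everything else is hypothetical.

; src: trunc(avgceilu(sext(x), sext(y))), with avgceilu spelled out in
; i17 so that the +1 and the shift cannot wrap.
define i8 @src(i8 %x, i8 %y) {
  %xw = sext i8 %x to i16
  %yw = sext i8 %y to i16
  %xz = zext i16 %xw to i17
  %yz = zext i16 %yw to i17
  %sum = add i17 %xz, %yz
  %inc = add i17 %sum, 1
  %avg = lshr i17 %inc, 1        ; unsigned ceiling average of the wide values
  %r = trunc i17 %avg to i8
  ret i8 %r
}

; tgt: avgceils(x, y) on the narrow type, via (x | y) - ((x ^ y) >>s 1),
; which computes the signed ceiling average without overflow.
define i8 @tgt(i8 %x, i8 %y) {
  %or = or i8 %x, %y
  %xor = xor i8 %x, %y
  %shr = ashr i8 %xor, 1
  %r = sub i8 %or, %shr
  ret i8 %r
}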