Skip to content

Commit 0773854

Browse files
[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits (#152273)
avgceil version: https://alive2.llvm.org/ce/z/2CKrRh

Fixes #147773

Co-authored-by: Simon Pilgrim <[email protected]>
1 parent d12f58f commit 0773854

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16279,6 +16279,40 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1627916279
// because targets may prefer a wider type during later combines and invert
1628016280
// this transform.
1628116281
switch (N0.getOpcode()) {
16282+
case ISD::AVGCEILU:
16283+
case ISD::AVGFLOORU:
16284+
if (!LegalOperations && N0.hasOneUse() &&
16285+
TLI.isOperationLegal(N0.getOpcode(), VT)) {
16286+
SDValue X = N0.getOperand(0);
16287+
SDValue Y = N0.getOperand(1);
16288+
unsigned SrcBits = X.getScalarValueSizeInBits();
16289+
unsigned DstBits = VT.getScalarSizeInBits();
16290+
APInt UpperBits = APInt::getBitsSetFrom(SrcBits, DstBits);
16291+
if (DAG.MaskedValueIsZero(X, UpperBits) &&
16292+
DAG.MaskedValueIsZero(Y, UpperBits)) {
16293+
SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
16294+
SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
16295+
return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
16296+
}
16297+
}
16298+
break;
16299+
case ISD::AVGCEILS:
16300+
case ISD::AVGFLOORS:
16301+
if (!LegalOperations && N0.hasOneUse() &&
16302+
TLI.isOperationLegal(N0.getOpcode(), VT)) {
16303+
SDValue X = N0.getOperand(0);
16304+
SDValue Y = N0.getOperand(1);
16305+
unsigned SrcBits = X.getScalarValueSizeInBits();
16306+
unsigned DstBits = VT.getScalarSizeInBits();
16307+
unsigned NeededSignBits = SrcBits - DstBits + 1;
16308+
if (DAG.ComputeNumSignBits(X) >= NeededSignBits &&
16309+
DAG.ComputeNumSignBits(Y) >= NeededSignBits) {
16310+
SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
16311+
SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
16312+
return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
16313+
}
16314+
}
16315+
break;
1628216316
case ISD::ADD:
1628316317
case ISD::SUB:
1628416318
case ISD::MUL:
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s
3+
4+
; Zero-extends both <8 x i8> operands to <8 x i16>, applies the NEON urhadd
; intrinsic at 16 bits, and truncates the result back to <8 x i8>. The CHECK
; lines expect this to select a single urhadd on 8-bit lanes (v0.8b), with no
; separate widening or narrowing instructions.
define <8 x i8> @avgceil_u_i8_to_i16(<8 x i8> %a, <8 x i8> %b) {
5+
; CHECK-LABEL: avgceil_u_i8_to_i16:
6+
; CHECK: // %bb.0:
7+
; CHECK-NEXT: urhadd v0.8b, v0.8b, v1.8b
8+
; CHECK-NEXT: ret
9+
%a16 = zext <8 x i8> %a to <8 x i16>
10+
%b16 = zext <8 x i8> %b to <8 x i16>
11+
%avg16 = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
12+
%r = trunc <8 x i16> %avg16 to <8 x i8>
13+
ret <8 x i8> %r
14+
}
15+
16+
17+
; Sign-extends both <8 x i8> operands to <8 x i16>, applies the NEON srhadd
; intrinsic at 16 bits, and truncates the result back to <8 x i8>. The CHECK
; lines expect this to select a single srhadd on 8-bit lanes (v0.8b), with no
; separate widening or narrowing instructions.
define <8 x i8> @test_avgceil_s(<8 x i8> %a, <8 x i8> %b) {
18+
; CHECK-LABEL: test_avgceil_s:
19+
; CHECK: // %bb.0:
20+
; CHECK-NEXT: srhadd v0.8b, v0.8b, v1.8b
21+
; CHECK-NEXT: ret
22+
%a16 = sext <8 x i8> %a to <8 x i16>
23+
%b16 = sext <8 x i8> %b to <8 x i16>
24+
%avg16 = call <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
25+
%res = trunc <8 x i16> %avg16 to <8 x i8>
26+
ret <8 x i8> %res
27+
}
28+
29+
; Zero-extends both <8 x i8> operands to <8 x i16>, applies the NEON uhadd
; intrinsic at 16 bits, and truncates the result back to <8 x i8>. The CHECK
; lines expect this to select a single uhadd on 8-bit lanes (v0.8b), with no
; separate widening or narrowing instructions.
define <8 x i8> @avgfloor_u_i8_to_i16(<8 x i8> %a, <8 x i8> %b) {
30+
; CHECK-LABEL: avgfloor_u_i8_to_i16:
31+
; CHECK: // %bb.0:
32+
; CHECK-NEXT: uhadd v0.8b, v0.8b, v1.8b
33+
; CHECK-NEXT: ret
34+
%a16 = zext <8 x i8> %a to <8 x i16>
35+
%b16 = zext <8 x i8> %b to <8 x i16>
36+
%avg16 = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
37+
%res = trunc <8 x i16> %avg16 to <8 x i8>
38+
ret <8 x i8> %res
39+
}
40+
41+
; Sign-extends both <8 x i8> operands to <8 x i16>, applies the NEON shadd
; intrinsic at 16 bits, and truncates the result back to <8 x i8>. The CHECK
; lines expect this to select a single shadd on 8-bit lanes (v0.8b), with no
; separate widening or narrowing instructions.
define <8 x i8> @test_avgfloor_s(<8 x i8> %a, <8 x i8> %b) {
42+
; CHECK-LABEL: test_avgfloor_s:
43+
; CHECK: // %bb.0:
44+
; CHECK-NEXT: shadd v0.8b, v0.8b, v1.8b
45+
; CHECK-NEXT: ret
46+
%a16 = sext <8 x i8> %a to <8 x i16>
47+
%b16 = sext <8 x i8> %b to <8 x i16>
48+
%avg16 = call <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
49+
%res = trunc <8 x i16> %avg16 to <8 x i8>
50+
ret <8 x i8> %res
51+
}
52+
53+

0 commit comments

Comments
 (0)