Skip to content

Commit 24287f7

Browse files
[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits -2
1 parent 80e303c commit 24287f7

File tree

2 files changed

+79
-33
lines changed

2 files changed

+79
-33
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16300,45 +16300,35 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1630016300
TLI.isOperationLegal(N0.getOpcode(), VT)) {
1630116301
SDValue X = N0.getOperand(0);
1630216302
SDValue Y = N0.getOperand(1);
16303-
16304-
KnownBits KnownX = DAG.computeKnownBits(X);
16305-
KnownBits KnownY = DAG.computeKnownBits(Y);
16306-
1630716303
unsigned SrcBits = X.getScalarValueSizeInBits();
1630816304
unsigned DstBits = VT.getScalarSizeInBits();
16309-
unsigned NeededLeadingZeros = SrcBits - DstBits + 1;
16310-
16311-
if (KnownX.countMinLeadingZeros() >= NeededLeadingZeros &&
16312-
KnownY.countMinLeadingZeros() >= NeededLeadingZeros) {
16305+
unsigned MaxBitsX = DAG.ComputeMaxSignificantBits(X);
16306+
unsigned MaxBitsY = DAG.ComputeMaxSignificantBits(Y);
16307+
if (MaxBitsX <= DstBits && MaxBitsY <= DstBits) {
1631316308
SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
1631416309
SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
1631516310
return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
1631616311
}
1631716312
}
1631816313
break;
16319-
1632016314
case ISD::AVGCEILS:
1632116315
case ISD::AVGFLOORS:
1632216316
if (!LegalOperations && N0.hasOneUse() &&
1632316317
TLI.isOperationLegal(N0.getOpcode(), VT)) {
1632416318
SDValue X = N0.getOperand(0);
1632516319
SDValue Y = N0.getOperand(1);
16326-
1632716320
unsigned SignBitsX = DAG.ComputeNumSignBits(X);
1632816321
unsigned SignBitsY = DAG.ComputeNumSignBits(Y);
16329-
1633016322
unsigned SrcBits = X.getScalarValueSizeInBits();
1633116323
unsigned DstBits = VT.getScalarSizeInBits();
1633216324
unsigned NeededSignBits = SrcBits - DstBits + 1;
16333-
1633416325
if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) {
1633516326
SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
1633616327
SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
1633716328
return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
1633816329
}
1633916330
}
1634016331
break;
16341-
1634216332
case ISD::ADD:
1634316333
case ISD::SUB:
1634416334
case ISD::MUL:
Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,91 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
12
; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s
23

3-
; CHECK-LABEL: test_avgceil_u
4-
; CHECK: uhadd v0.8b, v0.8b, v1.8b
4+
55
define <8 x i8> @test_avgceil_u(<8 x i16> %a, <8 x i16> %b) {
6-
%ta = trunc <8 x i16> %a to <8 x i8>
7-
%tb = trunc <8 x i16> %b to <8 x i8>
8-
%res = call <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
6+
; CHECK-LABEL: test_avgceil_u:
7+
; CHECK: // %bb.0:
8+
; CHECK-NEXT: xtn v0.8b, v0.8h
9+
; CHECK-NEXT: xtn v1.8b, v1.8h
10+
; CHECK-NEXT: uhadd v0.8b, v0.8b, v1.8b
11+
; CHECK-NEXT: ret
12+
13+
%mask = insertelement <8 x i16> undef, i16 255, i32 0
14+
%mask.splat = shufflevector <8 x i16> %mask, <8 x i16> undef, <8 x i32> zeroinitializer
15+
%ta16 = and <8 x i16> %a, %mask.splat
16+
%tb16 = and <8 x i16> %b, %mask.splat
17+
%ta8 = trunc <8 x i16> %ta16 to <8 x i8>
18+
%tb8 = trunc <8 x i16> %tb16 to <8 x i8>
19+
%res = call <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8)
920
ret <8 x i8> %res
1021
}
1122

12-
; CHECK-LABEL: test_avgceil_s
13-
; CHECK: shadd v0.8b, v0.8b, v1.8b
23+
1424
define <8 x i8> @test_avgceil_s(<8 x i16> %a, <8 x i16> %b) {
15-
%ta = trunc <8 x i16> %a to <8 x i8>
16-
%tb = trunc <8 x i16> %b to <8 x i8>
17-
%res = call <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
25+
; CHECK-LABEL: test_avgceil_s:
26+
; CHECK: // %bb.0:
27+
; CHECK-NEXT: sqxtn v0.8b, v0.8h
28+
; CHECK-NEXT: sqxtn v1.8b, v1.8h
29+
; CHECK-NEXT: shadd v0.8b, v0.8b, v1.8b
30+
; CHECK-NEXT: ret
31+
32+
%smin = insertelement <8 x i16> undef, i16 -128, i32 0
33+
%smax = insertelement <8 x i16> undef, i16 127, i32 0
34+
%min = shufflevector <8 x i16> %smin, <8 x i16> undef, <8 x i32> zeroinitializer
35+
%max = shufflevector <8 x i16> %smax, <8 x i16> undef, <8 x i32> zeroinitializer
36+
37+
%ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> %max)
38+
%ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> %min)
39+
%tb16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %b, <8 x i16> %max)
40+
%tb16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %tb16, <8 x i16> %min)
41+
42+
%ta8 = trunc <8 x i16> %ta16.clamped to <8 x i8>
43+
%tb8 = trunc <8 x i16> %tb16.clamped to <8 x i8>
44+
%res = call <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8)
1845
ret <8 x i8> %res
1946
}
2047

21-
; CHECK-LABEL: test_avgfloor_u
22-
; CHECK: urhadd v0.8b, v0.8b, v1.8b
48+
2349
define <8 x i8> @test_avgfloor_u(<8 x i16> %a, <8 x i16> %b) {
24-
%ta = trunc <8 x i16> %a to <8 x i8>
25-
%tb = trunc <8 x i16> %b to <8 x i8>
26-
%res = call <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
50+
; CHECK-LABEL: test_avgfloor_u:
51+
; CHECK: // %bb.0:
52+
; CHECK-NEXT: xtn v0.8b, v0.8h
53+
; CHECK-NEXT: xtn v1.8b, v1.8h
54+
; CHECK-NEXT: urhadd v0.8b, v0.8b, v1.8b
55+
; CHECK-NEXT: ret
56+
57+
%mask = insertelement <8 x i16> undef, i16 255, i32 0
58+
%mask.splat = shufflevector <8 x i16> %mask, <8 x i16> undef, <8 x i32> zeroinitializer
59+
%ta16 = and <8 x i16> %a, %mask.splat
60+
%tb16 = and <8 x i16> %b, %mask.splat
61+
%ta8 = trunc <8 x i16> %ta16 to <8 x i8>
62+
%tb8 = trunc <8 x i16> %tb16 to <8 x i8>
63+
%res = call <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8)
2764
ret <8 x i8> %res
2865
}
2966

30-
; CHECK-LABEL: test_avgfloor_s
31-
; CHECK: srhadd v0.8b, v0.8b, v1.8b
67+
3268
define <8 x i8> @test_avgfloor_s(<8 x i16> %a, <8 x i16> %b) {
33-
%ta = trunc <8 x i16> %a to <8 x i8>
34-
%tb = trunc <8 x i16> %b to <8 x i8>
35-
%res = call <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8> %ta, <8 x i8> %tb)
69+
; CHECK-LABEL: test_avgfloor_s:
70+
; CHECK: // %bb.0:
71+
; CHECK-NEXT: sqxtn v0.8b, v0.8h
72+
; CHECK-NEXT: sqxtn v1.8b, v1.8h
73+
; CHECK-NEXT: srhadd v0.8b, v0.8b, v1.8b
74+
; CHECK-NEXT: ret
75+
76+
%smin = insertelement <8 x i16> undef, i16 -128, i32 0
77+
%smax = insertelement <8 x i16> undef, i16 127, i32 0
78+
%min = shufflevector <8 x i16> %smin, <8 x i16> undef, <8 x i32> zeroinitializer
79+
%max = shufflevector <8 x i16> %smax, <8 x i16> undef, <8 x i32> zeroinitializer
80+
81+
%ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> %max)
82+
%ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> %min)
83+
%tb16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %b, <8 x i16> %max)
84+
%tb16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %tb16, <8 x i16> %min)
85+
86+
%ta8 = trunc <8 x i16> %ta16.clamped to <8 x i8>
87+
%tb8 = trunc <8 x i16> %tb16.clamped to <8 x i8>
88+
%res = call <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8)
3689
ret <8 x i8> %res
3790
}
3891

@@ -41,3 +94,6 @@ declare <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8>, <8 x i8>)
4194
declare <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8>, <8 x i8>)
4295
declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>)
4396

97+
declare <8 x i16> @llvm.smin.v8i16(<8 x i16>, <8 x i16>)
98+
declare <8 x i16> @llvm.smax.v8i16(<8 x i16>, <8 x i16>)
99+

0 commit comments

Comments
 (0)