Skip to content

Commit c8cc2a9

Browse files
[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits -3
1 parent 24287f7 commit c8cc2a9

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16302,9 +16302,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1630216302
SDValue Y = N0.getOperand(1);
1630316303
unsigned SrcBits = X.getScalarValueSizeInBits();
1630416304
unsigned DstBits = VT.getScalarSizeInBits();
16305-
unsigned MaxBitsX = DAG.ComputeMaxSignificantBits(X);
16306-
unsigned MaxBitsY = DAG.ComputeMaxSignificantBits(Y);
16307-
if (MaxBitsX <= DstBits && MaxBitsY <= DstBits) {
16305+
KnownBits KnownX = DAG.computeKnownBits(X);
16306+
KnownBits KnownY = DAG.computeKnownBits(Y);
16307+
if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) &&
16308+
KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) {
1630816309
SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
1630916310
SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
1631016311
return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
@@ -16322,6 +16323,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1632216323
unsigned SrcBits = X.getScalarValueSizeInBits();
1632316324
unsigned DstBits = VT.getScalarSizeInBits();
1632416325
unsigned NeededSignBits = SrcBits - DstBits + 1;
16326+
1632516327
if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) {
1632616328
SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
1632716329
SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);

llvm/test/CodeGen/AArch64/trunc-avg-fold.ll

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,8 @@ define <8 x i8> @test_avgceil_u(<8 x i16> %a, <8 x i16> %b) {
1010
; CHECK-NEXT: uhadd v0.8b, v0.8b, v1.8b
1111
; CHECK-NEXT: ret
1212

13-
%mask = insertelement <8 x i16> undef, i16 255, i32 0
14-
%mask.splat = shufflevector <8 x i16> %mask, <8 x i16> undef, <8 x i32> zeroinitializer
13+
%mask = insertelement <8 x i16> poison, i16 255, i32 0
14+
%mask.splat = shufflevector <8 x i16> %mask, <8 x i16> poison, <8 x i32> zeroinitializer
1515
%ta16 = and <8 x i16> %a, %mask.splat
1616
%tb16 = and <8 x i16> %b, %mask.splat
1717
%ta8 = trunc <8 x i16> %ta16 to <8 x i8>
@@ -29,10 +29,10 @@ define <8 x i8> @test_avgceil_s(<8 x i16> %a, <8 x i16> %b) {
2929
; CHECK-NEXT: shadd v0.8b, v0.8b, v1.8b
3030
; CHECK-NEXT: ret
3131

32-
%smin = insertelement <8 x i16> undef, i16 -128, i32 0
33-
%smax = insertelement <8 x i16> undef, i16 127, i32 0
34-
%min = shufflevector <8 x i16> %smin, <8 x i16> undef, <8 x i32> zeroinitializer
35-
%max = shufflevector <8 x i16> %smax, <8 x i16> undef, <8 x i32> zeroinitializer
32+
%smin = insertelement <8 x i16> poison, i16 -128, i32 0
33+
%smax = insertelement <8 x i16> poison, i16 127, i32 0
34+
%min = shufflevector <8 x i16> %smin, <8 x i16> poison, <8 x i32> zeroinitializer
35+
%max = shufflevector <8 x i16> %smax, <8 x i16> poison, <8 x i32> zeroinitializer
3636

3737
%ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> %max)
3838
%ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> %min)
@@ -73,10 +73,10 @@ define <8 x i8> @test_avgfloor_s(<8 x i16> %a, <8 x i16> %b) {
7373
; CHECK-NEXT: srhadd v0.8b, v0.8b, v1.8b
7474
; CHECK-NEXT: ret
7575

76-
%smin = insertelement <8 x i16> undef, i16 -128, i32 0
77-
%smax = insertelement <8 x i16> undef, i16 127, i32 0
78-
%min = shufflevector <8 x i16> %smin, <8 x i16> undef, <8 x i32> zeroinitializer
79-
%max = shufflevector <8 x i16> %smax, <8 x i16> undef, <8 x i32> zeroinitializer
76+
%smin = insertelement <8 x i16> poison, i16 -128, i32 0
77+
%smax = insertelement <8 x i16> poison, i16 127, i32 0
78+
%min = shufflevector <8 x i16> %smin, <8 x i16> poison, <8 x i32> zeroinitializer
79+
%max = shufflevector <8 x i16> %smax, <8 x i16> poison, <8 x i32> zeroinitializer
8080

8181
%ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> %max)
8282
%ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> %min)

0 commit comments

Comments
 (0)