-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits #152273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
80e303c
24287f7
c8cc2a9
1115256
08138a2
728b37d
44609a3
2d268fc
32041fb
4e1af14
c4ea7bd
6f84361
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16294,6 +16294,41 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { | |
| // because targets may prefer a wider type during later combines and invert | ||
| // this transform. | ||
| switch (N0.getOpcode()) { | ||
| case ISD::AVGCEILU: | ||
| case ISD::AVGFLOORU: | ||
| if (!LegalOperations && N0.hasOneUse() && | ||
| TLI.isOperationLegal(N0.getOpcode(), VT)) { | ||
| SDValue X = N0.getOperand(0); | ||
| SDValue Y = N0.getOperand(1); | ||
| unsigned SrcBits = X.getScalarValueSizeInBits(); | ||
| unsigned DstBits = VT.getScalarSizeInBits(); | ||
| unsigned MaxBitsX = DAG.ComputeMaxSignificantBits(X); | ||
| unsigned MaxBitsY = DAG.ComputeMaxSignificantBits(Y); | ||
| if (MaxBitsX <= DstBits && MaxBitsY <= DstBits) { | ||
| SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X); | ||
| SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y); | ||
| return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty); | ||
| } | ||
| } | ||
| break; | ||
| case ISD::AVGCEILS: | ||
| case ISD::AVGFLOORS: | ||
| if (!LegalOperations && N0.hasOneUse() && | ||
| TLI.isOperationLegal(N0.getOpcode(), VT)) { | ||
| SDValue X = N0.getOperand(0); | ||
| SDValue Y = N0.getOperand(1); | ||
| unsigned SignBitsX = DAG.ComputeNumSignBits(X); | ||
| unsigned SignBitsY = DAG.ComputeNumSignBits(Y); | ||
| unsigned SrcBits = X.getScalarValueSizeInBits(); | ||
| unsigned DstBits = VT.getScalarSizeInBits(); | ||
| unsigned NeededSignBits = SrcBits - DstBits + 1; | ||
| if (SignBitsX >= NeededSignBits && SignBitsY >= NeededSignBits) { | ||
| SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X); | ||
| SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y); | ||
|
||
| return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty); | ||
| } | ||
| } | ||
| break; | ||
| case ISD::ADD: | ||
| case ISD::SUB: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You should be able to reuse the ISD::ABD code later in the switch statement now - it has near-identical logic. |
||
| case ISD::MUL: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
| ; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| define <8 x i8> @test_avgceil_u(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgceil_u: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: xtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: xtn v1.8b, v1.8h | ||
| ; CHECK-NEXT: uhadd v0.8b, v0.8b, v1.8b | ||
| ; CHECK-NEXT: ret | ||
|
|
||
| %mask = insertelement <8 x i16> undef, i16 255, i32 0 | ||
| %mask.splat = shufflevector <8 x i16> %mask, <8 x i16> undef, <8 x i32> zeroinitializer | ||
| %ta16 = and <8 x i16> %a, %mask.splat | ||
| %tb16 = and <8 x i16> %b, %mask.splat | ||
| %ta8 = trunc <8 x i16> %ta16 to <8 x i8> | ||
| %tb8 = trunc <8 x i16> %tb16 to <8 x i8> | ||
| %res = call <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8) | ||
| ret <8 x i8> %res | ||
| } | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| define <8 x i8> @test_avgceil_s(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgceil_s: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: sqxtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: sqxtn v1.8b, v1.8h | ||
| ; CHECK-NEXT: shadd v0.8b, v0.8b, v1.8b | ||
| ; CHECK-NEXT: ret | ||
|
|
||
| %smin = insertelement <8 x i16> undef, i16 -128, i32 0 | ||
| %smax = insertelement <8 x i16> undef, i16 127, i32 0 | ||
| %min = shufflevector <8 x i16> %smin, <8 x i16> undef, <8 x i32> zeroinitializer | ||
| %max = shufflevector <8 x i16> %smax, <8 x i16> undef, <8 x i32> zeroinitializer | ||
|
|
||
| %ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> %max) | ||
| %ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> %min) | ||
| %tb16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %b, <8 x i16> %max) | ||
| %tb16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %tb16, <8 x i16> %min) | ||
|
|
||
| %ta8 = trunc <8 x i16> %ta16.clamped to <8 x i8> | ||
| %tb8 = trunc <8 x i16> %tb16.clamped to <8 x i8> | ||
| %res = call <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8) | ||
| ret <8 x i8> %res | ||
| } | ||
|
|
||
|
|
||
| define <8 x i8> @test_avgfloor_u(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgfloor_u: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: xtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: xtn v1.8b, v1.8h | ||
| ; CHECK-NEXT: urhadd v0.8b, v0.8b, v1.8b | ||
| ; CHECK-NEXT: ret | ||
|
|
||
| %mask = insertelement <8 x i16> undef, i16 255, i32 0 | ||
| %mask.splat = shufflevector <8 x i16> %mask, <8 x i16> undef, <8 x i32> zeroinitializer | ||
| %ta16 = and <8 x i16> %a, %mask.splat | ||
| %tb16 = and <8 x i16> %b, %mask.splat | ||
|
||
| %ta8 = trunc <8 x i16> %ta16 to <8 x i8> | ||
| %tb8 = trunc <8 x i16> %tb16 to <8 x i8> | ||
| %res = call <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8) | ||
| ret <8 x i8> %res | ||
| } | ||
|
|
||
|
|
||
| define <8 x i8> @test_avgfloor_s(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgfloor_s: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: sqxtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: sqxtn v1.8b, v1.8h | ||
| ; CHECK-NEXT: srhadd v0.8b, v0.8b, v1.8b | ||
| ; CHECK-NEXT: ret | ||
|
|
||
| %smin = insertelement <8 x i16> undef, i16 -128, i32 0 | ||
| %smax = insertelement <8 x i16> undef, i16 127, i32 0 | ||
| %min = shufflevector <8 x i16> %smin, <8 x i16> undef, <8 x i32> zeroinitializer | ||
| %max = shufflevector <8 x i16> %smax, <8 x i16> undef, <8 x i32> zeroinitializer | ||
|
|
||
| %ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> %max) | ||
| %ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> %min) | ||
| %tb16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %b, <8 x i16> %max) | ||
| %tb16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %tb16, <8 x i16> %min) | ||
|
|
||
| %ta8 = trunc <8 x i16> %ta16.clamped to <8 x i8> | ||
| %tb8 = trunc <8 x i16> %tb16.clamped to <8 x i8> | ||
| %res = call <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8> %ta8, <8 x i8> %tb8) | ||
| ret <8 x i8> %res | ||
| } | ||
|
|
||
| declare <8 x i8> @llvm.aarch64.neon.uhadd.v8i8(<8 x i8>, <8 x i8>) | ||
| declare <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8>, <8 x i8>) | ||
| declare <8 x i8> @llvm.aarch64.neon.urhadd.v8i8(<8 x i8>, <8 x i8>) | ||
| declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>) | ||
|
|
||
| declare <8 x i16> @llvm.smin.v8i16(<8 x i16>, <8 x i16>) | ||
| declare <8 x i16> @llvm.smax.v8i16(<8 x i16>, <8 x i16>) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`NeededLeadingZeros = SrcBits - DstBits;`? (`NeededSignBits` is correct, though you could use `ComputeMaxSignificantBits` instead if you wish.) There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I think you misunderstood - you need to use
`computeKnownBits(...).countMinLeadingZeros() >= (SrcBits - DstBits)`