-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits #152273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
80e303c
24287f7
c8cc2a9
1115256
08138a2
728b37d
44609a3
2d268fc
32041fb
4e1af14
c4ea7bd
6f84361
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16294,6 +16294,43 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { | |||||||||
| // because targets may prefer a wider type during later combines and invert | ||||||||||
| // this transform. | ||||||||||
| switch (N0.getOpcode()) { | ||||||||||
| case ISD::AVGCEILU: | ||||||||||
| case ISD::AVGFLOORU: | ||||||||||
| if (!LegalOperations && N0.hasOneUse() && | ||||||||||
| TLI.isOperationLegal(N0.getOpcode(), VT)) { | ||||||||||
| SDValue X = N0.getOperand(0); | ||||||||||
| SDValue Y = N0.getOperand(1); | ||||||||||
| unsigned SrcBits = X.getScalarValueSizeInBits(); | ||||||||||
| unsigned DstBits = VT.getScalarSizeInBits(); | ||||||||||
| KnownBits KnownX = DAG.computeKnownBits(X); | ||||||||||
| KnownBits KnownY = DAG.computeKnownBits(Y); | ||||||||||
| if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) && | ||||||||||
| KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) { | ||||||||||
|
||||||||||
| if (KnownX.countMinLeadingZeros() >= (SrcBits - DstBits) && | |
| KnownY.countMinLeadingZeros() >= (SrcBits - DstBits)) { | |
| if (KnownX.countMaxActiveBits() <= DstBits && | |
| KnownY.countMaxActiveBits() <= DstBits) { |
Then you don't need SrcBits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could even do this:
APInt UpperBits = APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
if (DAG.MaskedValueIsZero(X, UpperBits) &&
DAG.MaskedValueIsZero(Y, UpperBits)) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or APInt::getBitsSetFrom(SrcBits, DstBits). (Sometimes I think we have too many different helper functions!)
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NeededLeadingZeros = SrcBits - DstBits; ? (NeededSignBits is correct though you could use ComputeMaxSignificantBits instead if you wish)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I think you misunderstood - you need to use computeKnownBits.countMinLeadingZeros() >= (SrcBits - DstBits)
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (DAG.ComputeNumSignBits(X) >= NeededSignBits &&
DAG.ComputeNumSignBits(Y) >= NeededSignBits) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to reuse the ISD::ABD code later in the switch statement now - its has near-identical logic
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
| ; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| define <8 x i8> @test_avgceil_u(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgceil_u: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: bic v0.8h, #255, lsl #8 | ||
| ; CHECK-NEXT: bic v1.8h, #255, lsl #8 | ||
| ; CHECK-NEXT: uhadd v0.8h, v0.8h, v1.8h | ||
| ; CHECK-NEXT: xtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: ret | ||
| %ta16 = and <8 x i16> %a, splat (i16 255) | ||
| %tb16 = and <8 x i16> %b, splat (i16 255) | ||
| %avg16 = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %ta16, <8 x i16> %tb16) | ||
| %res = trunc <8 x i16> %avg16 to <8 x i8> | ||
| ret <8 x i8> %res | ||
| } | ||
RKSimon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| define <8 x i8> @test_avgceil_s(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgceil_s: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: movi v2.8h, #127 | ||
| ; CHECK-NEXT: mvni v3.8h, #127 | ||
| ; CHECK-NEXT: smin v0.8h, v0.8h, v2.8h | ||
| ; CHECK-NEXT: smin v1.8h, v1.8h, v2.8h | ||
| ; CHECK-NEXT: smax v0.8h, v0.8h, v3.8h | ||
| ; CHECK-NEXT: smax v1.8h, v1.8h, v3.8h | ||
| ; CHECK-NEXT: shadd v0.8h, v0.8h, v1.8h | ||
| ; CHECK-NEXT: xtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: ret | ||
| %ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> splat (i16 127)) | ||
| %ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> splat (i16 -128)) | ||
| %tb16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %b, <8 x i16> splat (i16 127)) | ||
| %tb16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %tb16, <8 x i16> splat (i16 -128)) | ||
| %avg16 = call <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16> %ta16.clamped, <8 x i16> %tb16.clamped) | ||
| %res = trunc <8 x i16> %avg16 to <8 x i8> | ||
| ret <8 x i8> %res | ||
| } | ||
|
|
||
| define <8 x i8> @test_avgfloor_u(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgfloor_u: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: bic v0.8h, #255, lsl #8 | ||
| ; CHECK-NEXT: bic v1.8h, #255, lsl #8 | ||
| ; CHECK-NEXT: uhadd v0.8h, v0.8h, v1.8h | ||
| ; CHECK-NEXT: xtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: ret | ||
| %ta16 = and <8 x i16> %a, splat (i16 255) | ||
| %tb16 = and <8 x i16> %b, splat (i16 255) | ||
| %avg16 = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %ta16, <8 x i16> %tb16) | ||
| %res = trunc <8 x i16> %avg16 to <8 x i8> | ||
| ret <8 x i8> %res | ||
| } | ||
|
|
||
| define <8 x i8> @test_avgfloor_s(<8 x i16> %a, <8 x i16> %b) { | ||
| ; CHECK-LABEL: test_avgfloor_s: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: movi v2.8h, #127 | ||
| ; CHECK-NEXT: mvni v3.8h, #127 | ||
| ; CHECK-NEXT: smin v0.8h, v0.8h, v2.8h | ||
| ; CHECK-NEXT: smin v1.8h, v1.8h, v2.8h | ||
| ; CHECK-NEXT: smax v0.8h, v0.8h, v3.8h | ||
| ; CHECK-NEXT: smax v1.8h, v1.8h, v3.8h | ||
| ; CHECK-NEXT: shadd v0.8h, v0.8h, v1.8h | ||
| ; CHECK-NEXT: xtn v0.8b, v0.8h | ||
| ; CHECK-NEXT: ret | ||
| %ta16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %a, <8 x i16> splat (i16 127)) | ||
| %ta16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %ta16, <8 x i16> splat (i16 -128)) | ||
| %tb16 = call <8 x i16> @llvm.smin.v8i16(<8 x i16> %b, <8 x i16> splat (i16 127)) | ||
| %tb16.clamped = call <8 x i16> @llvm.smax.v8i16(<8 x i16> %tb16, <8 x i16> splat (i16 -128)) | ||
| %avg16 = call <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16> %ta16.clamped, <8 x i16> %tb16.clamped) | ||
| %res = trunc <8 x i16> %avg16 to <8 x i8> | ||
| ret <8 x i8> %res | ||
| } | ||
|
|
||
| declare <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16>, <8 x i16>) | ||
| declare <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16>, <8 x i16>) | ||
| declare <8 x i16> @llvm.smin.v8i16(<8 x i16>, <8 x i16>) | ||
| declare <8 x i16> @llvm.smax.v8i16(<8 x i16>, <8 x i16>) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
computeKnownBits can be expensive. You should rearrange this code so that you only call computeKnownBits(Y) if the test on the result of computeKnownBits(X) succeeds.