[AArch64] Combine vector add(trunc(shift)) #169523
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

Changes

This adds a combine for
  add(trunc(ashr(A, C)), trunc(lshr(A, BW-1))), with C >= BW
->
  X = trunc(ashr(A, C)); add(X, lshr(X, BW-1))
The original converts into ashr+lshr+xtn+xtn+add. The second becomes ashr+xtn+usra. The first form has less total latency due to more parallelism, but more micro-ops and seems to be slower in practice.

Full diff: https://github.com/llvm/llvm-project/pull/169523.diff

2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e91f5a877b35b..8026b5e542f27 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -105,6 +105,7 @@
#include <vector>
using namespace llvm;
+using namespace llvm::SDPatternMatch;
#define DEBUG_TYPE "aarch64-lower"
@@ -22595,6 +22596,37 @@ static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) {
Flags);
}
+// add(trunc(ashr(A, C)), trunc(lshr(B, BW-1))), with C >= BW
+// ->
+// X = trunc(ashr(A, C)); add(x, lshr(X, BW-1)
+// The original converts into ashr+lshr+xtn+xtn+add. The second becomes
+// ashr+xtn+usra. The first form has less total latency due to more parallelism,
+// but more micro-ops and seems to be slower in practice.
+static SDValue performAddTrunkShiftCombine(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v2i32 && VT != MVT::v4i16 && VT != MVT::v8i8)
+ return SDValue();
+
+ SDValue AShr, LShr;
+ if (!sd_match(N, m_Add(m_Trunc(m_Value(AShr)), m_Trunc(m_Value(LShr)))))
+ return SDValue();
+ if (AShr.getOpcode() != AArch64ISD::VASHR)
+ std::swap(AShr, LShr);
+ if (AShr.getOpcode() != AArch64ISD::VASHR ||
+ LShr.getOpcode() != AArch64ISD::VLSHR ||
+ AShr.getOperand(0) != LShr.getOperand(0) ||
+ AShr.getConstantOperandVal(1) < VT.getScalarSizeInBits() ||
+ LShr.getConstantOperandVal(1) != VT.getScalarSizeInBits() * 2 - 1)
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, AShr);
+ SDValue Shift = DAG.getNode(
+ AArch64ISD::VLSHR, DL, VT, Trunc,
+ DAG.getTargetConstant(VT.getScalarSizeInBits() - 1, DL, MVT::i32));
+ return DAG.getNode(ISD::ADD, DL, VT, Trunc, Shift);
+}
+
static SDValue performAddSubCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Try to change sum of two reductions.
@@ -22618,6 +22650,8 @@ static SDValue performAddSubCombine(SDNode *N,
return Val;
if (SDValue Val = performSubWithBorrowCombine(N, DCI.DAG))
return Val;
+ if (SDValue Val = performAddTrunkShiftCombine(N, DCI.DAG))
+ return Val;
if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
return Val;
@@ -28125,7 +28159,6 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) {
static SDValue performCTPOPCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
- using namespace llvm::SDPatternMatch;
if (!DCI.isBeforeLegalize())
return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/addtruncshift.ll b/llvm/test/CodeGen/AArch64/addtruncshift.ll
new file mode 100644
index 0000000000000..6dbe0b3d80b9a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/addtruncshift.ll
@@ -0,0 +1,114 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
+; RUN: llc -mtriple=aarch64-none-elf -global-isel < %s | FileCheck %s --check-prefixes=CHECK,CHECK-GI
+
+define <2 x i32> @test_v2i64(<2 x i64> %n) {
+; CHECK-SD-LABEL: test_v2i64:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: sshr v0.2d, v0.2d, #35
+; CHECK-SD-NEXT: xtn v0.2s, v0.2d
+; CHECK-SD-NEXT: usra v0.2s, v0.2s, #31
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_v2i64:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushr v1.2d, v0.2d, #63
+; CHECK-GI-NEXT: sshr v0.2d, v0.2d, #35
+; CHECK-GI-NEXT: xtn v1.2s, v1.2d
+; CHECK-GI-NEXT: xtn v0.2s, v0.2d
+; CHECK-GI-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-GI-NEXT: ret
+entry:
+ %shr = lshr <2 x i64> %n, splat (i64 63)
+ %vmovn.i4 = trunc nuw nsw <2 x i64> %shr to <2 x i32>
+ %shr1 = ashr <2 x i64> %n, splat (i64 35)
+ %vmovn.i = trunc nsw <2 x i64> %shr1 to <2 x i32>
+ %add = add nsw <2 x i32> %vmovn.i4, %vmovn.i
+ ret <2 x i32> %add
+}
+
+define <4 x i16> @test_v4i32(<4 x i32> %n) {
+; CHECK-SD-LABEL: test_v4i32:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: sshr v0.4s, v0.4s, #17
+; CHECK-SD-NEXT: xtn v0.4h, v0.4s
+; CHECK-SD-NEXT: usra v0.4h, v0.4h, #15
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_v4i32:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushr v1.4s, v0.4s, #31
+; CHECK-GI-NEXT: sshr v0.4s, v0.4s, #17
+; CHECK-GI-NEXT: xtn v1.4h, v1.4s
+; CHECK-GI-NEXT: xtn v0.4h, v0.4s
+; CHECK-GI-NEXT: add v0.4h, v1.4h, v0.4h
+; CHECK-GI-NEXT: ret
+entry:
+ %shr = lshr <4 x i32> %n, splat (i32 31)
+ %vmovn.i4 = trunc nuw nsw <4 x i32> %shr to <4 x i16>
+ %shr1 = ashr <4 x i32> %n, splat (i32 17)
+ %vmovn.i = trunc nsw <4 x i32> %shr1 to <4 x i16>
+ %add = add nsw <4 x i16> %vmovn.i4, %vmovn.i
+ ret <4 x i16> %add
+}
+
+define <8 x i8> @test_v8i16(<8 x i16> %n) {
+; CHECK-SD-LABEL: test_v8i16:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: sshr v0.8h, v0.8h, #9
+; CHECK-SD-NEXT: xtn v0.8b, v0.8h
+; CHECK-SD-NEXT: usra v0.8b, v0.8b, #7
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_v8i16:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushr v1.8h, v0.8h, #15
+; CHECK-GI-NEXT: sshr v0.8h, v0.8h, #9
+; CHECK-GI-NEXT: xtn v1.8b, v1.8h
+; CHECK-GI-NEXT: xtn v0.8b, v0.8h
+; CHECK-GI-NEXT: add v0.8b, v1.8b, v0.8b
+; CHECK-GI-NEXT: ret
+entry:
+ %shr = lshr <8 x i16> %n, splat (i16 15)
+ %vmovn.i4 = trunc nuw nsw <8 x i16> %shr to <8 x i8>
+ %shr1 = ashr <8 x i16> %n, splat (i16 9)
+ %vmovn.i = trunc nsw <8 x i16> %shr1 to <8 x i8>
+ %add = add nsw <8 x i8> %vmovn.i4, %vmovn.i
+ ret <8 x i8> %add
+}
+
+define <2 x i32> @test_v2i64_smallsrl(<2 x i64> %n) {
+; CHECK-LABEL: test_v2i64_smallsrl:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ushr v1.2d, v0.2d, #62
+; CHECK-NEXT: sshr v0.2d, v0.2d, #35
+; CHECK-NEXT: xtn v1.2s, v1.2d
+; CHECK-NEXT: xtn v0.2s, v0.2d
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %shr = lshr <2 x i64> %n, splat (i64 62)
+ %vmovn.i4 = trunc nuw nsw <2 x i64> %shr to <2 x i32>
+ %shr1 = ashr <2 x i64> %n, splat (i64 35)
+ %vmovn.i = trunc nsw <2 x i64> %shr1 to <2 x i32>
+ %add = add nsw <2 x i32> %vmovn.i4, %vmovn.i
+ ret <2 x i32> %add
+}
+
+define <2 x i32> @test_v2i64_smallsra(<2 x i64> %n) {
+; CHECK-LABEL: test_v2i64_smallsra:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ushr v1.2d, v0.2d, #63
+; CHECK-NEXT: shrn v0.2s, v0.2d, #27
+; CHECK-NEXT: xtn v1.2s, v1.2d
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+entry:
+ %shr = lshr <2 x i64> %n, splat (i64 63)
+ %vmovn.i4 = trunc nuw nsw <2 x i64> %shr to <2 x i32>
+ %shr1 = ashr <2 x i64> %n, splat (i64 27)
+ %vmovn.i = trunc nsw <2 x i64> %shr1 to <2 x i32>
+ %add = add nsw <2 x i32> %vmovn.i4, %vmovn.i
+ ret <2 x i32> %add
+}
+
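As an aside (a minimal sketch, not part of the patch; the value set, BW = 32 and C = 35 are illustrative), the scalar identity the combine relies on can be checked directly: with C >= BW, bit BW-1 of trunc(ashr(A, C)) is A's sign bit, i.e. the same bit that trunc(lshr(A, 2*BW-1)) produces, so shifting the truncated value is enough.

```c++
// Standalone sanity check of the scalar identity behind the combine
// (illustrative only, not part of the patch). Assumes arithmetic right
// shift of signed values, as clang/gcc provide and C++20 guarantees.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t Vals[] = {INT64_MIN, -12345678901LL, -1, 0, 1,
                          12345678901LL, INT64_MAX};
  const unsigned BW = 32; // narrow element width
  const unsigned C = 35;  // ashr amount; the combine requires C >= BW
  for (int64_t A : Vals) {
    // Original form: add(trunc(ashr(A, C)), trunc(lshr(A, 2*BW-1))).
    uint32_t Orig =
        (uint32_t)(A >> C) + (uint32_t)((uint64_t)A >> (2 * BW - 1));
    // Combined form: X = trunc(ashr(A, C)); add(X, lshr(X, BW-1)).
    uint32_t X = (uint32_t)(A >> C);
    uint32_t Folded = X + (X >> (BW - 1));
    // With C >= BW, bit BW-1 of X is A's sign bit, so both forms agree.
    assert(Orig == Folded);
  }
  return 0;
}
```

This is also why the test_v2i64_smallsrl and test_v2i64_smallsra cases above stay un-combined: with a smaller lshr amount or with C < BW, the extracted bit is no longer the sign bit and the fold would be incorrect.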
🐧 Linux x64 Test Results

All tests passed but another part of the build failed. Click on a failure below to see the details.

lib/Target/AArch64/CMakeFiles/LLVMAArch64CodeGen.dir/AArch64ISelLowering.cpp.o

If these failures are unrelated to your changes (for example tests are broken or flaky at HEAD), please open an issue at https://github.com/llvm/llvm-project/issues and add the
SamTebbs33
left a comment
LGTM
      Flags);
}

// add(trunc(ashr(A, C)), trunc(lshr(B, BW-1))), with C >= BW
Should B here be A? That seems to be what the code suggests.
usha1830
left a comment
LGTM, Thanks.
// The original converts into ashr+lshr+xtn+xtn+add. The second becomes
// ashr+xtn+usra. The first form has less total latency due to more parallelism,
// but more micro-ops and seems to be slower in practice.
static SDValue performAddTrunkShiftCombine(SDNode *N, SelectionDAG &DAG) {
Suggested change:
- static SDValue performAddTrunkShiftCombine(SDNode *N, SelectionDAG &DAG) {
+ static SDValue performAddTruncShiftCombine(SDNode *N, SelectionDAG &DAG) {
Force-pushed from 1faa946 to 9ea52af
Force-pushed from 9ea52af to d01e4cc
This adds a combine for
add(trunc(ashr(A, C)), trunc(lshr(A, BW-1))), with C >= BW
->
X = trunc(ashr(A, C)); add(X, lshr(X, BW-1))
The original converts into ashr+lshr+xtn+xtn+add. The second becomes
ashr+xtn+usra. The first form has less total latency due to more parallelism,
but more micro-ops and seems to be slower in practice.