Skip to content

Commit d50b01b

Browse files
committed
[AArch64] Transform add(x, abs(y)) -> saba(x, y, 0)
1 parent 5e31eef commit d50b01b

File tree

3 files changed

+226
-147
lines changed

3 files changed

+226
-147
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "llvm/CodeGen/MachineInstrBuilder.h"
5151
#include "llvm/CodeGen/MachineMemOperand.h"
5252
#include "llvm/CodeGen/MachineRegisterInfo.h"
53+
#include "llvm/CodeGen/SDPatternMatch.h"
5354
#include "llvm/CodeGen/SelectionDAG.h"
5455
#include "llvm/CodeGen/SelectionDAGNodes.h"
5556
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -21914,6 +21915,56 @@ static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) {
2191421915
return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift);
2191521916
}
2191621917

21918+
// Transform the following:
21919+
// - add(x, abs(y)) -> saba(x, y, 0)
21920+
// - add(x, zext(abs(y))) -> sabal(x, y, 0)
21921+
static SDValue performAddSABACombine(SDNode *N,
21922+
TargetLowering::DAGCombinerInfo &DCI) {
21923+
if (N->getOpcode() != ISD::ADD)
21924+
return SDValue();
21925+
21926+
EVT VT = N->getValueType(0);
21927+
if (!VT.isFixedLengthVector())
21928+
return SDValue();
21929+
21930+
SDValue N0 = N->getOperand(0);
21931+
SDValue N1 = N->getOperand(1);
21932+
21933+
auto MatchAbsOrZExtAbs = [](SDValue V0, SDValue V1, SDValue &AbsOp,
21934+
SDValue &Other, bool &IsZExt) {
21935+
Other = V1;
21936+
if (sd_match(V0, m_Abs(SDPatternMatch::m_Value(AbsOp)))) {
21937+
IsZExt = false;
21938+
return true;
21939+
}
21940+
if (sd_match(V0, SDPatternMatch::m_ZExt(
21941+
m_Abs(SDPatternMatch::m_Value(AbsOp))))) {
21942+
IsZExt = true;
21943+
return true;
21944+
}
21945+
21946+
return false;
21947+
};
21948+
21949+
SDValue AbsOp;
21950+
SDValue Other;
21951+
bool IsZExt;
21952+
if (!MatchAbsOrZExtAbs(N0, N1, AbsOp, Other, IsZExt) &&
21953+
!MatchAbsOrZExtAbs(N1, N0, AbsOp, Other, IsZExt))
21954+
return SDValue();
21955+
21956+
// Don't perform this on abs(sub), as this will become an ABD/ABA anyway.
21957+
if (AbsOp.getOpcode() == ISD::SUB)
21958+
return SDValue();
21959+
21960+
SDLoc DL(N);
21961+
SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
21962+
SDValue Zeros = DCI.DAG.getSplatVector(AbsOp.getValueType(), DL, Zero);
21963+
21964+
unsigned Opcode = IsZExt ? AArch64ISD::SABAL : AArch64ISD::SABA;
21965+
return DCI.DAG.getNode(Opcode, DL, VT, Other, AbsOp, Zeros);
21966+
}
21967+
2191721968
static SDValue performAddSubCombine(SDNode *N,
2191821969
TargetLowering::DAGCombinerInfo &DCI) {
2191921970
// Try to change sum of two reductions.
@@ -21939,6 +21990,9 @@ static SDValue performAddSubCombine(SDNode *N,
2193921990
if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
2194021991
return Val;
2194121992

21993+
if (SDValue Val = performAddSABACombine(N, DCI))
21994+
return Val;
21995+
2194221996
return performAddSubLongCombine(N, DCI);
2194321997
}
2194421998

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,10 @@ def AArch64sdot : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>;
10591059
def AArch64udot : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>;
10601060
def AArch64usdot : SDNode<"AArch64ISD::USDOT", SDT_AArch64Dot>;
10611061

1062+
// saba/sabal
1063+
def AArch64neonsaba : SDNode<"AArch64ISD::SABA", SDT_AArch64trivec>;
1064+
def AArch64neonsabal : SDNode<"AArch64ISD::SABAL", SDT_AArch64Dot>;
1065+
10621066
// Vector across-lanes addition
10631067
// Only the lower result lane is defined.
10641068
def AArch64saddv : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
@@ -6121,6 +6125,19 @@ defm SQRDMLAH : SIMDThreeSameVectorSQRDMLxHTiedHS<1,0b10000,"sqrdmlah",
61216125
defm SQRDMLSH : SIMDThreeSameVectorSQRDMLxHTiedHS<1,0b10001,"sqrdmlsh",
61226126
int_aarch64_neon_sqrdmlsh>;
61236127

6128+
def : Pat<(AArch64neonsaba (v8i8 V64:$Rd), V64:$Rn, V64:$Rm),
6129+
(SABAv8i8 V64:$Rd, V64:$Rn, V64:$Rm)>;
6130+
def : Pat<(AArch64neonsaba (v4i16 V64:$Rd), V64:$Rn, V64:$Rm),
6131+
(SABAv4i16 V64:$Rd, V64:$Rn, V64:$Rm)>;
6132+
def : Pat<(AArch64neonsaba (v2i32 V64:$Rd), V64:$Rn, V64:$Rm),
6133+
(SABAv2i32 V64:$Rd, V64:$Rn, V64:$Rm)>;
6134+
def : Pat<(AArch64neonsaba (v16i8 V128:$Rd), V128:$Rn, V128:$Rm),
6135+
(SABAv16i8 V128:$Rd, V128:$Rn, V128:$Rm)>;
6136+
def : Pat<(AArch64neonsaba (v8i16 V128:$Rd), V128:$Rn, V128:$Rm),
6137+
(SABAv8i16 V128:$Rd, V128:$Rn, V128:$Rm)>;
6138+
def : Pat<(AArch64neonsaba (v4i32 V128:$Rd), V128:$Rn, V128:$Rm),
6139+
(SABAv4i32 V128:$Rd, V128:$Rn, V128:$Rm)>;
6140+
61246141
defm AND : SIMDLogicalThreeVector<0, 0b00, "and", and>;
61256142
defm BIC : SIMDLogicalThreeVector<0, 0b01, "bic",
61266143
BinOpFrag<(and node:$LHS, (vnot node:$RHS))> >;
@@ -7008,6 +7025,14 @@ defm : AddSubHNPatterns<ADDHNv2i64_v2i32, ADDHNv2i64_v4i32,
70087025
SUBHNv2i64_v2i32, SUBHNv2i64_v4i32,
70097026
v2i32, v2i64, 32>;
70107027

7028+
// Patterns for SABAL
7029+
def : Pat<(AArch64neonsabal (v8i16 V128:$Rd), (v8i8 V64:$Rn), (v8i8 V64:$Rm)),
7030+
(SABALv8i8_v8i16 V128:$Rd, V64:$Rn, V64:$Rm)>;
7031+
def : Pat<(AArch64neonsabal (v4i32 V128:$Rd), (v4i16 V64:$Rn), (v4i16 V64:$Rm)),
7032+
(SABALv4i16_v4i32 V128:$Rd, V64:$Rn, V64:$Rm)>;
7033+
def : Pat<(AArch64neonsabal (v2i64 V128:$Rd), (v2i32 V64:$Rn), (v2i32 V64:$Rm)),
7034+
(SABALv2i32_v2i64 V128:$Rd, V64:$Rn, V64:$Rm)>;
7035+
70117036
//----------------------------------------------------------------------------
70127037
// AdvSIMD bitwise extract from vector instruction.
70137038
//----------------------------------------------------------------------------

0 commit comments

Comments
 (0)