Skip to content

Commit 4830e63

Browse files
[LLVM][CodeGen][AArch64] Improve lowering of boolean vector popcount operations. (#166401)
1 parent cbb9b0e commit 4830e63

File tree

2 files changed

+350
-1
lines changed

2 files changed

+350
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "llvm/CodeGen/MachineInstrBuilder.h"
5151
#include "llvm/CodeGen/MachineMemOperand.h"
5252
#include "llvm/CodeGen/MachineRegisterInfo.h"
53+
#include "llvm/CodeGen/SDPatternMatch.h"
5354
#include "llvm/CodeGen/SelectionDAG.h"
5455
#include "llvm/CodeGen/SelectionDAGNodes.h"
5556
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -104,7 +105,6 @@
104105
#include <vector>
105106

106107
using namespace llvm;
107-
using namespace llvm::PatternMatch;
108108

109109
#define DEBUG_TYPE "aarch64-lower"
110110

@@ -1174,6 +1174,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11741174

11751175
setTargetDAGCombine(ISD::SHL);
11761176
setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
1177+
setTargetDAGCombine(ISD::CTPOP);
11771178

11781179
// In case of strict alignment, avoid an excessive number of byte wide stores.
11791180
MaxStoresPerMemsetOptSize = 8;
@@ -17555,6 +17556,7 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
1755517556
// udot instruction.
1755617557
if (SrcWidth * 4 <= DstWidth) {
1755717558
if (all_of(I->users(), [&](auto *U) {
17559+
using namespace llvm::PatternMatch;
1755817560
auto *SingleUser = cast<Instruction>(&*U);
1755917561
if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))))
1756017562
return true;
@@ -17826,6 +17828,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
1782617828
// into shift / and masks. For the moment we do this just for uitofp (not
1782717829
// zext) to avoid issues with widening instructions.
1782817830
if (Shuffles.size() == 4 && all_of(Shuffles, [](ShuffleVectorInst *SI) {
17831+
using namespace llvm::PatternMatch;
1782917832
return SI->hasOneUse() && match(SI->user_back(), m_UIToFP(m_Value())) &&
1783017833
SI->getType()->getScalarSizeInBits() * 4 ==
1783117834
SI->user_back()->getType()->getScalarSizeInBits();
@@ -27842,6 +27845,35 @@ static SDValue performRNDRCombine(SDNode *N, SelectionDAG &DAG) {
2784227845
{A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
2784327846
}
2784427847

27848+
static SDValue performCTPOPCombine(SDNode *N,
27849+
TargetLowering::DAGCombinerInfo &DCI,
27850+
SelectionDAG &DAG) {
27851+
using namespace llvm::SDPatternMatch;
27852+
if (!DCI.isBeforeLegalize())
27853+
return SDValue();
27854+
27855+
// ctpop(zext(bitcast(vector_mask))) -> neg(signed_reduce_add(vector_mask))
27856+
SDValue Mask;
27857+
if (!sd_match(N->getOperand(0), m_ZExt(m_BitCast(m_Value(Mask)))))
27858+
return SDValue();
27859+
27860+
EVT VT = N->getValueType(0);
27861+
EVT MaskVT = Mask.getValueType();
27862+
27863+
if (VT.isVector() || !MaskVT.isFixedLengthVector() ||
27864+
MaskVT.getVectorElementType() != MVT::i1)
27865+
return SDValue();
27866+
27867+
EVT ReduceInVT =
27868+
EVT::getVectorVT(*DAG.getContext(), VT, MaskVT.getVectorElementCount());
27869+
27870+
SDLoc DL(N);
27871+
// Sign extend to best fit ZeroOrNegativeOneBooleanContent.
27872+
SDValue ExtMask = DAG.getNode(ISD::SIGN_EXTEND, DL, ReduceInVT, Mask);
27873+
SDValue NegPopCount = DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, ExtMask);
27874+
return DAG.getNegative(NegPopCount, DL, VT);
27875+
}
27876+
2784527877
SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2784627878
DAGCombinerInfo &DCI) const {
2784727879
SelectionDAG &DAG = DCI.DAG;
@@ -28187,6 +28219,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2818728219
return performScalarToVectorCombine(N, DCI, DAG);
2818828220
case ISD::SHL:
2818928221
return performSHLCombine(N, DCI, DAG);
28222+
case ISD::CTPOP:
28223+
return performCTPOPCombine(N, DCI, DAG);
2819028224
}
2819128225
return SDValue();
2819228226
}

0 commit comments

Comments
 (0)