Skip to content

Commit 6ee5b3e

Browse files
committed
[X86] AVX512 optimised CTLZ/CTTZ implementations for i256/i512 scalars
Make use of AVX512 VPLZCNT/VPOPCNT to perform the big-integer bit counts per vector element, and then use VPCOMPRESS to extract the result for the first non-zero element.
1 parent b4c4013 commit 6ee5b3e

File tree

2 files changed

+220
-210
lines changed

2 files changed

+220
-210
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2654,6 +2654,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26542654
ISD::AVGCEILU,
26552655
ISD::AVGFLOORS,
26562656
ISD::AVGFLOORU,
2657+
ISD::CTLZ,
2658+
ISD::CTTZ,
2659+
ISD::CTLZ_ZERO_UNDEF,
2660+
ISD::CTTZ_ZERO_UNDEF,
26572661
ISD::BITREVERSE,
26582662
ISD::ADD,
26592663
ISD::FADD,
@@ -55162,6 +55166,61 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5516255166
return combineFneg(N, DAG, DCI, Subtarget);
5516355167
}
5516455168

55169+
// Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55170+
// vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55171+
// Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55172+
// the result to remove all zero elements (passthru is set to scalar bitwidth if
55173+
// all elements are zero) and extract the lowest compressed element.
55174+
static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55175+
TargetLowering::DAGCombinerInfo &DCI,
55176+
const X86Subtarget &Subtarget) {
55177+
EVT VT = N->getValueType(0);
55178+
SDValue N0 = N->getOperand(0);
55179+
unsigned Opc = N->getOpcode();
55180+
unsigned SizeInBits = VT.getSizeInBits();
55181+
assert((Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF || Opc == ISD::CTTZ ||
55182+
Opc == ISD::CTTZ_ZERO_UNDEF) &&
55183+
"Unsupported bit count");
55184+
55185+
if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55186+
((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55187+
(SizeInBits == 256 && Subtarget.hasVLX() &&
55188+
X86::mayFoldLoad(N0, Subtarget)))) {
55189+
MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55190+
MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55191+
SDValue Vec = DAG.getBitcast(VecVT, N0);
55192+
SDLoc DL(N);
55193+
55194+
SmallVector<int, 8> RevMask;
55195+
SmallVector<SDValue, 8> Offsets;
55196+
for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55197+
RevMask.push_back((int)((E - 1) - I));
55198+
Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55199+
}
55200+
55201+
// CTLZ - reverse the elements as we want the top non-zero element.
55202+
if (Opc == ISD::CTLZ)
55203+
Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55204+
55205+
SDValue PassThrough = DAG.getUNDEF(VecVT);
55206+
if (Opc == ISD::CTLZ || Opc == ISD::CTTZ)
55207+
PassThrough = DAG.getConstant(SizeInBits, DL, VecVT);
55208+
55209+
SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55210+
DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55211+
SDValue Cnt = DAG.getNode(Opc, DL, VecVT, Vec);
55212+
Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55213+
DAG.getBuildVector(VecVT, DL, Offsets));
55214+
Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55215+
PassThrough);
55216+
Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55217+
DAG.getVectorIdxConstant(0, DL));
55218+
return DAG.getZExtOrTrunc(Cnt, DL, VT);
55219+
}
55220+
55221+
return SDValue();
55222+
}
55223+
5516555224
static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5516655225
TargetLowering::DAGCombinerInfo &DCI,
5516755226
const X86Subtarget &Subtarget) {
@@ -60885,6 +60944,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6088560944
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6088660945
case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6088760946
case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60947+
case ISD::CTLZ:
60948+
case ISD::CTTZ:
60949+
case ISD::CTLZ_ZERO_UNDEF:
60950+
case ISD::CTTZ_ZERO_UNDEF:return combineCTZ(N, DAG, DCI, Subtarget);
6088860951
case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6088960952
case ISD::AVGCEILS:
6089060953
case ISD::AVGCEILU:

0 commit comments

Comments
 (0)