Skip to content

Commit 5e2c7e7

Browse files
committed
[X86] AVX512 optimised CTLZ/CTTZ implementations for i256/i512 scalars
Make use of AVX512 VPLZCNT/VPOPCNT to perform the big integer bit counts per vector element and then use VPCOMPRESS to extract the first non-zero element result
1 parent b4c4013 commit 5e2c7e7

File tree

2 files changed

+210
-210
lines changed

2 files changed

+210
-210
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2654,6 +2654,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26542654
ISD::AVGCEILU,
26552655
ISD::AVGFLOORS,
26562656
ISD::AVGFLOORU,
2657+
ISD::CTLZ,
2658+
ISD::CTTZ,
26572659
ISD::BITREVERSE,
26582660
ISD::ADD,
26592661
ISD::FADD,
@@ -55162,6 +55164,55 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5516255164
return combineFneg(N, DAG, DCI, Subtarget);
5516355165
}
5516455166

55167+
// Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55168+
// vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55169+
// Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55170+
// the result to remove all zero elements (passthru is set to scalar bitwidth if
55171+
// all elements are zero) and extract the lowest compressed element.
55172+
static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55173+
TargetLowering::DAGCombinerInfo &DCI,
55174+
const X86Subtarget &Subtarget) {
55175+
EVT VT = N->getValueType(0);
55176+
SDValue N0 = N->getOperand(0);
55177+
unsigned Opc = N->getOpcode();
55178+
unsigned SizeInBits = VT.getSizeInBits();
55179+
assert((Opc == ISD::CTLZ || Opc == ISD::CTTZ) && "Unsupported bit count");
55180+
55181+
if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55182+
((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55183+
(SizeInBits == 256 && Subtarget.hasVLX() &&
55184+
X86::mayFoldLoad(N0, Subtarget)))) {
55185+
MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55186+
MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55187+
SDValue Vec = DAG.getBitcast(VecVT, N0);
55188+
SDLoc DL(N0);
55189+
55190+
SmallVector<int, 8> RevMask;
55191+
SmallVector<SDValue, 8> Offsets;
55192+
for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55193+
RevMask.push_back((int)((E - 1) - I));
55194+
Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55195+
}
55196+
55197+
// CTLZ - reverse the elements as we want the top non-zero element.
55198+
if (Opc == ISD::CTLZ)
55199+
Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55200+
55201+
SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55202+
DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55203+
SDValue Cnt = DAG.getNode(Opc, DL, VecVT, Vec);
55204+
Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55205+
DAG.getBuildVector(VecVT, DL, Offsets));
55206+
Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55207+
DAG.getConstant(SizeInBits, DL, VecVT));
55208+
Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55209+
DAG.getVectorIdxConstant(0, DL));
55210+
return DAG.getZExtOrTrunc(Cnt, DL, VT);
55211+
}
55212+
55213+
return SDValue();
55214+
}
55215+
5516555216
static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5516655217
TargetLowering::DAGCombinerInfo &DCI,
5516755218
const X86Subtarget &Subtarget) {
@@ -60885,6 +60936,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6088560936
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6088660937
case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6088760938
case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60939+
case ISD::CTLZ:
60940+
case ISD::CTTZ: return combineCTZ(N, DAG, DCI, Subtarget);
6088860941
case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6088960942
case ISD::AVGCEILS:
6089060943
case ISD::AVGCEILU:

0 commit comments

Comments
 (0)