Skip to content

Commit 7630ca0

Browse files
committed
[X86] AVX512 optimised CTLZ/CTTZ implementations for i256/i512 scalars
Make use of AVX512 VPLZCNT/VPOPCNT to perform the big integer bit counts per vector element and then use VPCOMPRESS to extract the first non-zero element result
1 parent 72616c5 commit 7630ca0

File tree

2 files changed

+210
-210
lines changed

2 files changed

+210
-210
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,6 +2652,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26522652
ISD::AVGCEILU,
26532653
ISD::AVGFLOORS,
26542654
ISD::AVGFLOORU,
2655+
ISD::CTLZ,
2656+
ISD::CTTZ,
26552657
ISD::BITREVERSE,
26562658
ISD::ADD,
26572659
ISD::FADD,
@@ -55130,6 +55132,55 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5513055132
return combineFneg(N, DAG, DCI, Subtarget);
5513155133
}
5513255134

55135+
// Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55136+
// vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55137+
// Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55138+
// the result to remove all zero elements (passthru is set to scalar bitwidth if
55139+
// all elements are zero) and extract the lowest compressed element.
55140+
static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55141+
TargetLowering::DAGCombinerInfo &DCI,
55142+
const X86Subtarget &Subtarget) {
55143+
EVT VT = N->getValueType(0);
55144+
SDValue N0 = N->getOperand(0);
55145+
unsigned Opc = N->getOpcode();
55146+
unsigned SizeInBits = VT.getSizeInBits();
55147+
assert((Opc == ISD::CTLZ || Opc == ISD::CTTZ) && "Unsupported bit count");
55148+
55149+
if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55150+
((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55151+
(SizeInBits == 256 && Subtarget.hasVLX() &&
55152+
X86::mayFoldLoad(N0, Subtarget)))) {
55153+
MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55154+
MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55155+
SDValue Vec = DAG.getBitcast(VecVT, N0);
55156+
SDLoc DL(N0);
55157+
55158+
SmallVector<int, 8> RevMask;
55159+
SmallVector<SDValue, 8> Offsets;
55160+
for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55161+
RevMask.push_back((int)((E - 1) - I));
55162+
Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55163+
}
55164+
55165+
// CTLZ - reverse the elements as we want the top non-zero element.
55166+
if (Opc == ISD::CTLZ)
55167+
Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55168+
55169+
SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55170+
DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55171+
SDValue Cnt = DAG.getNode(Opc, DL, VecVT, Vec);
55172+
Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55173+
DAG.getBuildVector(VecVT, DL, Offsets));
55174+
Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55175+
DAG.getConstant(SizeInBits, DL, VecVT));
55176+
Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55177+
DAG.getVectorIdxConstant(0, DL));
55178+
return DAG.getZExtOrTrunc(Cnt, DL, VT);
55179+
}
55180+
55181+
return SDValue();
55182+
}
55183+
5513355184
static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5513455185
TargetLowering::DAGCombinerInfo &DCI,
5513555186
const X86Subtarget &Subtarget) {
@@ -60804,6 +60855,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6080460855
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6080560856
case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6080660857
case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60858+
case ISD::CTLZ:
60859+
case ISD::CTTZ: return combineCTZ(N, DAG, DCI, Subtarget);
6080760860
case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6080860861
case ISD::AVGCEILS:
6080960862
case ISD::AVGCEILU:

0 commit comments

Comments
 (0)