@@ -2654,6 +2654,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26542654 ISD::AVGCEILU,
26552655 ISD::AVGFLOORS,
26562656 ISD::AVGFLOORU,
2657+ ISD::CTLZ,
2658+ ISD::CTTZ,
26572659 ISD::BITREVERSE,
26582660 ISD::ADD,
26592661 ISD::FADD,
@@ -55162,6 +55164,55 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5516255164 return combineFneg(N, DAG, DCI, Subtarget);
5516355165}
5516455166
55167+ // Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55168+ // vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55169+ // Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55170+ // the result to remove all zero elements (passthru is set to scalar bitwidth if
55171+ // all elements are zero) and extract the lowest compressed element.
55172+ static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55173+ TargetLowering::DAGCombinerInfo &DCI,
55174+ const X86Subtarget &Subtarget) {
55175+ EVT VT = N->getValueType(0);
55176+ SDValue N0 = N->getOperand(0);
55177+ unsigned Opc = N->getOpcode();
55178+ unsigned SizeInBits = VT.getSizeInBits();
55179+ assert((Opc == ISD::CTLZ || Opc == ISD::CTTZ) && "Unsupported bit count");
55180+
55181+ if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55182+ ((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55183+ (SizeInBits == 256 && Subtarget.hasVLX() &&
55184+ X86::mayFoldLoad(N0, Subtarget)))) {
55185+ MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55186+ MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55187+ SDValue Vec = DAG.getBitcast(VecVT, N0);
55188+ SDLoc DL(N0);
55189+
55190+ SmallVector<int, 8> RevMask;
55191+ SmallVector<SDValue, 8> Offsets;
55192+ for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55193+ RevMask.push_back((int)((E - 1) - I));
55194+ Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55195+ }
55196+
55197+ // CTLZ - reverse the elements as we want the top non-zero element.
55198+ if (Opc == ISD::CTLZ)
55199+ Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55200+
55201+ SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55202+ DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55203+ SDValue Cnt = DAG.getNode(Opc, DL, VecVT, Vec);
55204+ Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55205+ DAG.getBuildVector(VecVT, DL, Offsets));
55206+ Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55207+ DAG.getConstant(SizeInBits, DL, VecVT));
55208+ Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55209+ DAG.getVectorIdxConstant(0, DL));
55210+ return DAG.getZExtOrTrunc(Cnt, DL, VT);
55211+ }
55212+
55213+ return SDValue();
55214+ }
55215+
5516555216static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5516655217 TargetLowering::DAGCombinerInfo &DCI,
5516755218 const X86Subtarget &Subtarget) {
@@ -60885,6 +60936,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6088560936 case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6088660937 case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6088760938 case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60939+ case ISD::CTLZ:
60940+ case ISD::CTTZ: return combineCTZ(N, DAG, DCI, Subtarget);
6088860941 case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6088960942 case ISD::AVGCEILS:
6089060943 case ISD::AVGCEILU:
0 commit comments