@@ -2654,6 +2654,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26542654 ISD::AVGCEILU,
26552655 ISD::AVGFLOORS,
26562656 ISD::AVGFLOORU,
2657+ ISD::CTLZ,
2658+ ISD::CTTZ,
2659+ ISD::CTLZ_ZERO_UNDEF,
2660+ ISD::CTTZ_ZERO_UNDEF,
26572661 ISD::BITREVERSE,
26582662 ISD::ADD,
26592663 ISD::FADD,
@@ -55162,6 +55166,61 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5516255166 return combineFneg(N, DAG, DCI, Subtarget);
5516355167}
5516455168
55169+ // Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55170+ // vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55171+ // Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55172+ // the result to remove all zero elements (passthru is set to scalar bitwidth if
55173+ // all elements are zero) and extract the lowest compressed element.
55174+ static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55175+ TargetLowering::DAGCombinerInfo &DCI,
55176+ const X86Subtarget &Subtarget) {
55177+ EVT VT = N->getValueType(0);
55178+ SDValue N0 = N->getOperand(0);
55179+ unsigned Opc = N->getOpcode();
55180+ unsigned SizeInBits = VT.getSizeInBits();
55181+ assert((Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF || Opc == ISD::CTTZ ||
55182+ Opc == ISD::CTTZ_ZERO_UNDEF) &&
55183+ "Unsupported bit count");
55184+
55185+ if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55186+ ((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55187+ (SizeInBits == 256 && Subtarget.hasVLX() &&
55188+ X86::mayFoldLoad(N0, Subtarget)))) {
55189+ MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55190+ MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55191+ SDValue Vec = DAG.getBitcast(VecVT, N0);
55192+ SDLoc DL(N);
55193+
55194+ SmallVector<int, 8> RevMask;
55195+ SmallVector<SDValue, 8> Offsets;
55196+ for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55197+ RevMask.push_back((int)((E - 1) - I));
55198+ Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55199+ }
55200+
55201+ // CTLZ - reverse the elements as we want the top non-zero element.
55202+ if (Opc == ISD::CTLZ)
55203+ Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55204+
55205+ SDValue PassThrough = DAG.getUNDEF(VecVT);
55206+ if (Opc == ISD::CTLZ || Opc == ISD::CTTZ)
55207+ PassThrough = DAG.getConstant(SizeInBits, DL, VecVT);
55208+
55209+ SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55210+ DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55211+ SDValue Cnt = DAG.getNode(Opc, DL, VecVT, Vec);
55212+ Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55213+ DAG.getBuildVector(VecVT, DL, Offsets));
55214+ Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55215+ PassThrough);
55216+ Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55217+ DAG.getVectorIdxConstant(0, DL));
55218+ return DAG.getZExtOrTrunc(Cnt, DL, VT);
55219+ }
55220+
55221+ return SDValue();
55222+ }
55223+
5516555224static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5516655225 TargetLowering::DAGCombinerInfo &DCI,
5516755226 const X86Subtarget &Subtarget) {
@@ -60885,6 +60944,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6088560944 case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6088660945 case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6088760946 case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60947+ case ISD::CTLZ:
60948+ case ISD::CTTZ:
60949+ case ISD::CTLZ_ZERO_UNDEF:
60950+ case ISD::CTTZ_ZERO_UNDEF:return combineCTZ(N, DAG, DCI, Subtarget);
6088860951 case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6088960952 case ISD::AVGCEILS:
6089060953 case ISD::AVGCEILU:
0 commit comments