Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2654,6 +2654,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
ISD::AVGCEILU,
ISD::AVGFLOORS,
ISD::AVGFLOORU,
ISD::CTLZ,
ISD::CTTZ,
ISD::CTLZ_ZERO_UNDEF,
ISD::CTTZ_ZERO_UNDEF,
ISD::BITREVERSE,
ISD::ADD,
ISD::FADD,
Expand Down Expand Up @@ -55162,6 +55166,65 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
return combineFneg(N, DAG, DCI, Subtarget);
}

// Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
// vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
// Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
// the result to remove all zero elements (passthru is set to scalar bitwidth if
// all elements are zero) and extract the lowest compressed element.
static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
unsigned Opc = N->getOpcode();
unsigned SizeInBits = VT.getSizeInBits();
assert((Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF || Opc == ISD::CTTZ ||
Opc == ISD::CTTZ_ZERO_UNDEF) &&
"Unsupported bit count");

if (VT.isScalarInteger() && Subtarget.hasCDI() &&
((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
(SizeInBits == 256 && Subtarget.hasVLX() &&
X86::mayFoldLoad(N0, Subtarget)))) {
MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
SDValue Vec = DAG.getBitcast(VecVT, N0);
SDLoc DL(N);

SmallVector<int, 8> RevMask;
SmallVector<SDValue, 8> Offsets;
for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
RevMask.push_back((int)((E - 1) - I));
Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
}

// CTLZ - reverse the elements as we want the top non-zero element at the
// bottom for compression.
unsigned VecOpc = ISD::CTTZ;
if (Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF) {
VecOpc = ISD::CTLZ;
Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be 64 instead of SizeInBits?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No - I've added the offsets to each element at this point, the pass through is for the compress and will only appear in elt[0] if the entire vector is zero - in which case it should return the full scalar integer width (256/512).


SDValue PassThrough = DAG.getUNDEF(VecVT);
if (Opc == ISD::CTLZ || Opc == ISD::CTTZ)
PassThrough = DAG.getConstant(SizeInBits, DL, VecVT);

SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
DAG.getConstant(0, DL, VecVT), ISD::SETNE);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we missing Cnt for IsNonZero?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch - yes, we mustn't use the ZERO_UNDEF variants on the vector op.

SDValue Cnt = DAG.getNode(VecOpc, DL, VecVT, Vec);
Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
DAG.getBuildVector(VecVT, DL, Offsets));
Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
PassThrough);
Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
DAG.getVectorIdxConstant(0, DL));
return DAG.getZExtOrTrunc(Cnt, DL, VT);
}

return SDValue();
}

static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
Expand Down Expand Up @@ -60885,6 +60948,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
case ISD::CTLZ:
case ISD::CTTZ:
case ISD::CTLZ_ZERO_UNDEF:
case ISD::CTTZ_ZERO_UNDEF:return combineCTZ(N, DAG, DCI, Subtarget);
case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
case ISD::AVGCEILS:
case ISD::AVGCEILU:
Expand Down
Loading
Loading