218 changes: 157 additions & 61 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1435,6 +1435,20 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::BITREVERSE, VT, Custom);
}

// Vector min/max reductions
if (Subtarget.hasSSE41())
{
for (MVT VT : MVT::vector_valuetypes()) {
Collaborator: SSE should only be declaring v8i16/v16i8 custom if we can.

Author: X86TTIImpl::shouldExpandReduction does a more thorough check. If we can't lower them properly, they will get expanded to shuffles and we will never see these reductions in the DAG.

Collaborator: That sounds like the ExpandReductions pass isn't accounting for type legalisation properly?

Author: Now I see what you mean. I added checks here to only do lowering for vector types that support PHMINPOSUW.

if (VT.getScalarType() == MVT::i8 || VT.getScalarType() == MVT::i16)
{
setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
}
}
}

if (!Subtarget.useSoftFloat() && Subtarget.hasAVX()) {
bool HasInt256 = Subtarget.hasInt256();

@@ -25409,6 +25423,94 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
return SignExt;
}

// Create a min/max v8i16/v16i8 horizontal reduction with PHMINPOSUW.
static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT, SDLoc DL,
ISD::NodeType BinOp, SelectionDAG &DAG,
const X86Subtarget &Subtarget)
{
assert(Subtarget.hasSSE41() && "The caller must check if SSE4.1 is available");

EVT SrcVT = Src.getValueType();
EVT SrcSVT = SrcVT.getScalarType();

if (SrcSVT != TargetVT || (SrcVT.getSizeInBits() % 128) != 0)
return SDValue();

// First, reduce the source down to 128-bit, applying BinOp to lo/hi.
while (SrcVT.getSizeInBits() > 128) {
SDValue Lo, Hi;
std::tie(Lo, Hi) = splitVector(Src, DAG, DL);
SrcVT = Lo.getValueType();
Src = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
}
assert(((SrcVT == MVT::v8i16 && TargetVT == MVT::i16) ||
(SrcVT == MVT::v16i8 && TargetVT == MVT::i8)) &&
"Unexpected value type");

// PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
// to flip the value accordingly.
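// For example, with SMAX on i16 elements the mask is 0x7FFF: XOR maps
// INT16_MAX (0x7FFF) to 0x0000 and INT16_MIN (0x8000) to 0xFFFF, so the
// signed maximum becomes the unsigned minimum found by PHMINPOSUW, and
// XOR'ing the result with the same mask afterwards recovers the value.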
SDValue Mask;
unsigned MaskEltsBits = TargetVT.getSizeInBits();
if (BinOp == ISD::SMAX)
Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
else if (BinOp == ISD::SMIN)
Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
else if (BinOp == ISD::UMAX)
Mask = DAG.getAllOnesConstant(DL, SrcVT);

if (Mask)
Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);

// For v16i8 cases we need to perform UMIN on pairs of byte elements,
// shuffling each upper element down and insert zeros. This means that the
// v16i8 UMIN will leave the upper element as zero, performing zero-extension
// ready for the PHMINPOS.
if (TargetVT == MVT::i8) {
SDValue Upper = DAG.getVectorShuffle(
SrcVT, DL, Src, DAG.getConstant(0, DL, MVT::v16i8),
{1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
Src = DAG.getNode(ISD::UMIN, DL, SrcVT, Src, Upper);
}

// Perform the PHMINPOS on a v8i16 vector,
Src = DAG.getBitcast(MVT::v8i16, Src);
Src = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, Src);
Src = DAG.getBitcast(SrcVT, Src);

if (Mask)
Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);

return DAG.getExtractVectorElt(DL, TargetVT, Src, 0);
}

static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op,
const X86Subtarget& Subtarget,
SelectionDAG& DAG)
{
ISD::NodeType BinOp;
switch (Op.getOpcode())
{
default:
assert(false && "Expected min/max reduction");
break;
case ISD::VECREDUCE_UMIN:
BinOp = ISD::UMIN;
break;
case ISD::VECREDUCE_UMAX:
BinOp = ISD::UMAX;
break;
case ISD::VECREDUCE_SMIN:
BinOp = ISD::SMIN;
break;
case ISD::VECREDUCE_SMAX:
BinOp = ISD::SMAX;
break;
}
Collaborator: replace this with ISD::NodeType BinOp = ISD::getVecReduceBaseOpcode(Op.getOpcode()) - you can then assert that BinOp is a min/max opcode if you want.


return createMinMaxReduction(Op->getOperand(0), Op.getValueType(), SDLoc(Op),
BinOp, DAG, Subtarget);
}
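
For reference, a minimal sketch of the simplification suggested in the review comment above (assuming ISD::getVecReduceBaseOpcode is available in this tree and maps the VECREDUCE_* opcodes to their ISD::UMIN/UMAX/SMIN/SMAX counterparts, as the reviewer's snippet implies):

static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op,
                                         const X86Subtarget &Subtarget,
                                         SelectionDAG &DAG) {
  // Map the VECREDUCE_* opcode to its scalar binop and sanity-check it.
  ISD::NodeType BinOp = ISD::getVecReduceBaseOpcode(Op.getOpcode());
  assert((BinOp == ISD::UMIN || BinOp == ISD::UMAX || BinOp == ISD::SMIN ||
          BinOp == ISD::SMAX) &&
         "Expected min/max reduction");
  return createMinMaxReduction(Op.getOperand(0), Op.getValueType(), SDLoc(Op),
                               BinOp, DAG, Subtarget);
}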

static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT VT = Op->getSimpleValueType(0);
@@ -33620,6 +33722,11 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::SIGN_EXTEND_VECTOR_INREG:
return LowerEXTEND_VECTOR_INREG(Op, Subtarget, DAG);
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_SMAX:
return LowerVECTOR_REDUCE_MINMAX(Op, Subtarget, DAG);
case ISD::FP_TO_SINT:
case ISD::STRICT_FP_TO_SINT:
case ISD::FP_TO_UINT:
@@ -46192,60 +46299,8 @@ static SDValue combineMinMaxReduction(SDNode *Extract, SelectionDAG &DAG,
if (!Src)
return SDValue();

EVT SrcVT = Src.getValueType();
EVT SrcSVT = SrcVT.getScalarType();
if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0)
return SDValue();

SDLoc DL(Extract);
SDValue MinPos = Src;

// First, reduce the source down to 128-bit, applying BinOp to lo/hi.
while (SrcVT.getSizeInBits() > 128) {
SDValue Lo, Hi;
std::tie(Lo, Hi) = splitVector(MinPos, DAG, DL);
SrcVT = Lo.getValueType();
MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
}
assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) ||
(SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) &&
"Unexpected value type");

// PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
// to flip the value accordingly.
SDValue Mask;
unsigned MaskEltsBits = ExtractVT.getSizeInBits();
if (BinOp == ISD::SMAX)
Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
else if (BinOp == ISD::SMIN)
Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
else if (BinOp == ISD::UMAX)
Mask = DAG.getAllOnesConstant(DL, SrcVT);

if (Mask)
MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);

// For v16i8 cases we need to perform UMIN on pairs of byte elements,
// shuffling each upper element down and insert zeros. This means that the
// v16i8 UMIN will leave the upper element as zero, performing zero-extension
// ready for the PHMINPOS.
if (ExtractVT == MVT::i8) {
SDValue Upper = DAG.getVectorShuffle(
SrcVT, DL, MinPos, DAG.getConstant(0, DL, MVT::v16i8),
{1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper);
}

// Perform the PHMINPOS on a v8i16 vector,
MinPos = DAG.getBitcast(MVT::v8i16, MinPos);
MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos);
MinPos = DAG.getBitcast(SrcVT, MinPos);

if (Mask)
MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);

return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, MinPos,
DAG.getVectorIdxConstant(0, DL));
return createMinMaxReduction(Src, ExtractVT, SDLoc(Extract),
Contributor: Is this still required?

Collaborator: We have some CodeGen/X86 tests that match the raw shuffle chain pattern instead of the reduction intrinsic, and they fail without this - it's probably another case of #143088, where we need to improve the middle-end to ensure those patterns get folded into reduction intrinsics so the backend doesn't need to handle them.

Collaborator: I've raised #144654

BinOp, DAG, Subtarget);
}

// Attempt to replace an all_of/any_of/parity style horizontal reduction with a MOVMSK.
@@ -47081,7 +47136,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
/// scalars back, while for x64 we should use 64-bit extracts and shifts.
static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
const X86Subtarget &Subtarget,
bool& TransformedBinOpReduction) {
if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
return NewOp;

@@ -47169,23 +47225,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
// Check whether this extract is the root of a sum of absolute differences
// pattern. This has to be done here because we really want it to happen
// pre-legalization,
if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) {
TransformedBinOpReduction = true;
return SAD;
}

if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) {
TransformedBinOpReduction = true;
return VPDPBUSD;
}

// Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) {
TransformedBinOpReduction = true;
return Cmp;
}

// Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) {
TransformedBinOpReduction = true;
return MinMax;
}

// Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
if (SDValue V = combineArithReduction(N, DAG, Subtarget))
if (SDValue V = combineArithReduction(N, DAG, Subtarget)) {
TransformedBinOpReduction = true;
return V;
}

if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI))
return V;
@@ -47255,6 +47321,36 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue combineExtractVectorEltAndOperand(SDNode* N, SelectionDAG& DAG,
Contributor: Is the combine independent of moving code to the lowering? Better to split them if it is.

Author: They are independent, but in the end they serve the same purpose. Should I open a separate PR for the lowering?

Contributor: Yes. That would be good to show which code affects the test change.

TargetLowering::DAGCombinerInfo& DCI,
const X86Subtarget& Subtarget)
{
bool TransformedBinOpReduction = false;
auto Op = combineExtractVectorElt(N, DAG, DCI, Subtarget, TransformedBinOpReduction);

if (TransformedBinOpReduction)
{
// If we simplified N = extract_vector_elt(V, 0) into Op and V resulted from
// a reduction, then we need to replace all uses of V with
// scalar_to_vector(Op) to make sure we eliminate the binop + shuffle
// pyramid. This is safe to do because the elements of V are undefined except
// for the zeroth element and Op does not depend on V.

auto OldV = N->getOperand(0);
assert(!Op.getNode()->hasPredecessor(OldV.getNode()) &&
"Op must not depend on the converted reduction");

auto NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op);

auto NV = DCI.CombineTo(N, Op);
DCI.CombineTo(OldV.getNode(), NewV);

Op = NV; // Return N so it doesn't get rechecked!
}

return Op;
}

// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
// This is more or less the reverse of combineBitcastvxi1.
static SDValue combineToExtendBoolVectorInReg(
@@ -60702,7 +60798,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::EXTRACT_VECTOR_ELT:
case X86ISD::PEXTRW:
case X86ISD::PEXTRB:
return combineExtractVectorElt(N, DAG, DCI, Subtarget);
return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget);
case ISD::CONCAT_VECTORS:
return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
case ISD::INSERT_SUBVECTOR:
19 changes: 19 additions & 0 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6575,6 +6575,25 @@ X86TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
return Options;
}

bool llvm::X86TTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
switch (II->getIntrinsicID()) {
default:
return true;

case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_umax:
case Intrinsic::vector_reduce_smin:
case Intrinsic::vector_reduce_smax:
auto *VType = cast<FixedVectorType>(II->getOperand(0)->getType());
auto SType = VType->getScalarType();
bool CanUsePHMINPOSUW =
ST->hasSSE41() && II->getType() == SType &&
(VType->getPrimitiveSizeInBits() % 128) == 0 &&
(SType->isIntegerTy(8) || SType->isIntegerTy(16));
return !CanUsePHMINPOSUW;
}
}

bool X86TTIImpl::prefersVectorizedAddressing() const {
return supportsGather();
}
1 change: 1 addition & 0 deletions llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -303,6 +303,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
TTI::MemCmpExpansionOptions
enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const override;
bool preferAlternateOpcodeVectorization() const override { return false; }
bool shouldExpandReduction(const IntrinsicInst *II) const override;
bool prefersVectorizedAddressing() const override;
bool supportsEfficientVectorElementLoadStore() const override;
bool enableInterleavedAccessVectorization() const override;