Skip to content

Commit 2ab84a6

Browse files
authored
[X86][FP16][BF16] Improve vectorization of fcmp (#116153)
1 parent c2bb056 commit 2ab84a6

File tree

10 files changed

+165
-689
lines changed

10 files changed

+165
-689
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 44 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1711,6 +1711,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17111711
setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
17121712
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
17131713
}
1714+
setOperationAction(ISD::SETCC, MVT::v8f16, Custom);
1715+
setOperationAction(ISD::SETCC, MVT::v16f16, Custom);
17141716
}
17151717

17161718
// This block controls legalization of the mask vector sizes that are
@@ -2046,6 +2048,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20462048
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom);
20472049
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
20482050
setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
2051+
setOperationAction(ISD::SETCC, MVT::v32f16, Custom);
20492052

20502053
for (auto VT : { MVT::v16i32, MVT::v8i64, MVT::v16f32, MVT::v8f64 }) {
20512054
setOperationAction(ISD::MLOAD, VT, Legal);
@@ -2401,6 +2404,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24012404
setOperationPromotedToType(Opc, MVT::v8bf16, MVT::v8f32);
24022405
setOperationPromotedToType(Opc, MVT::v16bf16, MVT::v16f32);
24032406
}
2407+
setOperationAction(ISD::SETCC, MVT::v8bf16, Custom);
2408+
setOperationAction(ISD::SETCC, MVT::v16bf16, Custom);
24042409
setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
24052410
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
24062411
}
@@ -2411,6 +2416,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24112416
setF16Action(MVT::v32bf16, Expand);
24122417
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
24132418
setOperationPromotedToType(Opc, MVT::v32bf16, MVT::v32f32);
2419+
setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
24142420
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
24152421
setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
24162422
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
@@ -23396,14 +23402,12 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
2339623402
return SSECC;
2339723403
}
2339823404

23399-
/// Break a VSETCC 256-bit integer VSETCC into two new 128 ones and then
23405+
/// Break a VSETCC 256/512-bit vector into two new 128/256 ones and then
2340023406
/// concatenate the result back.
23401-
static SDValue splitIntVSETCC(EVT VT, SDValue LHS, SDValue RHS,
23402-
ISD::CondCode Cond, SelectionDAG &DAG,
23403-
const SDLoc &dl) {
23404-
assert(VT.isInteger() && VT == LHS.getValueType() &&
23405-
VT == RHS.getValueType() && "Unsupported VTs!");
23406-
23407+
static SDValue splitVSETCC(EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond,
23408+
SelectionDAG &DAG, const SDLoc &dl) {
23409+
assert(VT.isInteger() && LHS.getValueType() == RHS.getValueType() &&
23410+
"Unsupported VTs!");
2340723411
SDValue CC = DAG.getCondCode(Cond);
2340823412

2340923413
// Extract the LHS Lo/Hi vectors
@@ -23544,18 +23548,43 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2354423548
SDValue CC = Op.getOperand(IsStrict ? 3 : 2);
2354523549
MVT VT = Op->getSimpleValueType(0);
2354623550
ISD::CondCode Cond = cast<CondCodeSDNode>(CC)->get();
23547-
bool isFP = Op1.getSimpleValueType().isFloatingPoint();
23551+
MVT OpVT = Op0.getSimpleValueType();
2354823552
SDLoc dl(Op);
2354923553

23550-
if (isFP) {
23551-
MVT EltVT = Op0.getSimpleValueType().getVectorElementType();
23554+
if (OpVT.isFloatingPoint()) {
23555+
MVT EltVT = OpVT.getVectorElementType();
2355223556
assert(EltVT == MVT::bf16 || EltVT == MVT::f16 || EltVT == MVT::f32 ||
2355323557
EltVT == MVT::f64);
23554-
if (isSoftF16(EltVT, Subtarget))
23555-
return SDValue();
2355623558

23557-
bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2355823559
SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
23560+
if (isSoftF16(EltVT, Subtarget)) {
23561+
// Break 256-bit FP vector compare into smaller ones.
23562+
if (OpVT.is256BitVector() && !Subtarget.useAVX512Regs())
23563+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23564+
23565+
// Break 512-bit FP vector compare into smaller ones.
23566+
if (OpVT.is512BitVector())
23567+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23568+
23569+
MVT NVT = OpVT.changeVectorElementType(MVT::f32);
23570+
if (IsStrict) {
23571+
Op0 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23572+
{Chain, Op0});
23573+
Op1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23574+
{Chain, Op1});
23575+
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
23576+
{Chain, Op0, Op1, CC});
23577+
}
23578+
MVT DVT = VT.getVectorElementType() == MVT::i16
23579+
? VT.changeVectorElementType(MVT::i32)
23580+
: VT;
23581+
SDValue Cmp = DAG.getNode(Op.getOpcode(), dl, DVT,
23582+
DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op0),
23583+
DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op1), CC);
23584+
return DVT == VT ? Cmp : DAG.getNode(ISD::TRUNCATE, dl, VT, Cmp);
23585+
}
23586+
23587+
bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2355923588

2356023589
// If we have a strict compare with a vXi1 result and the input is 128/256
2356123590
// bits we can't use a masked compare unless we have VLX. If we use a wider
@@ -23766,12 +23795,12 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2376623795

2376723796
// Break 256-bit integer vector compare into smaller ones.
2376823797
if (VT.is256BitVector() && !Subtarget.hasInt256())
23769-
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23798+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
2377023799

2377123800
// Break 512-bit integer vector compare into smaller ones.
2377223801
// TODO: Try harder to use VPCMPx + VPMOV2x?
2377323802
if (VT.is512BitVector())
23774-
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23803+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
2377523804

2377623805
// If we have a limit constant, try to form PCMPGT (signed cmp) to avoid
2377723806
// not-of-PCMPEQ:

0 commit comments

Comments
 (0)