Skip to content

Commit 007fc23

Browse files
committed
[X86][FP16][BF16] Improve vectorization of fcmp
1 parent e9aee4f commit 007fc23

File tree

10 files changed

+161
-686
lines changed

10 files changed

+161
-686
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,6 +1711,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17111711
setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
17121712
setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
17131713
}
1714+
setOperationAction(ISD::SETCC, MVT::v8f16, Custom);
1715+
setOperationAction(ISD::SETCC, MVT::v16f16, Custom);
17141716
}
17151717

17161718
// This block controls legalization of the mask vector sizes that are
@@ -2046,6 +2048,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20462048
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom);
20472049
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
20482050
setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
2051+
setOperationAction(ISD::SETCC, MVT::v32f16, Custom);
20492052

20502053
for (auto VT : { MVT::v16i32, MVT::v8i64, MVT::v16f32, MVT::v8f64 }) {
20512054
setOperationAction(ISD::MLOAD, VT, Legal);
@@ -2401,6 +2404,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24012404
setOperationPromotedToType(Opc, MVT::v8bf16, MVT::v8f32);
24022405
setOperationPromotedToType(Opc, MVT::v16bf16, MVT::v16f32);
24032406
}
2407+
setOperationAction(ISD::SETCC, MVT::v8bf16, Custom);
2408+
setOperationAction(ISD::SETCC, MVT::v16bf16, Custom);
24042409
setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
24052410
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
24062411
}
@@ -2411,6 +2416,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24112416
setF16Action(MVT::v32bf16, Expand);
24122417
for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
24132418
setOperationPromotedToType(Opc, MVT::v32bf16, MVT::v32f32);
2419+
setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
24142420
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
24152421
setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
24162422
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
@@ -23333,12 +23339,8 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
2333323339

2333423340
/// Break a 256/512-bit integer or FP VSETCC into two half-width ones and then
2333523341
/// concatenate the result back.
23336-
static SDValue splitIntVSETCC(EVT VT, SDValue LHS, SDValue RHS,
23337-
ISD::CondCode Cond, SelectionDAG &DAG,
23338-
const SDLoc &dl) {
23339-
assert(VT.isInteger() && VT == LHS.getValueType() &&
23340-
VT == RHS.getValueType() && "Unsupported VTs!");
23341-
23342+
static SDValue splitVSETCC(EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond,
23343+
SelectionDAG &DAG, const SDLoc &dl) {
2334223344
SDValue CC = DAG.getCondCode(Cond);
2334323345

2334423346
// Extract the LHS Lo/Hi vectors
@@ -23483,14 +23485,40 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2348323485
SDLoc dl(Op);
2348423486

2348523487
if (isFP) {
23486-
MVT EltVT = Op0.getSimpleValueType().getVectorElementType();
23488+
MVT SVT = Op0.getSimpleValueType();
23489+
MVT EltVT = SVT.getVectorElementType();
2348723490
assert(EltVT == MVT::bf16 || EltVT == MVT::f16 || EltVT == MVT::f32 ||
2348823491
EltVT == MVT::f64);
23489-
if (isSoftF16(EltVT, Subtarget))
23490-
return SDValue();
2349123492

23492-
bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2349323493
SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
23494+
if (isSoftF16(EltVT, Subtarget)) {
23495+
// Break 256-bit FP vector compare into smaller ones.
23496+
if (SVT.is256BitVector() && !Subtarget.useAVX512Regs())
23497+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23498+
23499+
// Break 512-bit FP vector compare into smaller ones.
23500+
if (SVT.is512BitVector())
23501+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23502+
23503+
MVT NVT = SVT.changeVectorElementType(MVT::f32);
23504+
if (IsStrict) {
23505+
Op0 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23506+
{Chain, Op0});
23507+
Op1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23508+
{Chain, Op1});
23509+
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
23510+
{Chain, Op0, Op1, CC});
23511+
}
23512+
MVT DVT = VT.getVectorElementType() == MVT::i16
23513+
? VT.changeVectorElementType(MVT::i32)
23514+
: VT;
23515+
SDValue Cmp = DAG.getNode(Op.getOpcode(), dl, DVT,
23516+
DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op0),
23517+
DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op1), CC);
23518+
return DVT == VT ? Cmp : DAG.getNode(ISD::TRUNCATE, dl, VT, Cmp);
23519+
}
23520+
23521+
bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2349423522

2349523523
// If we have a strict compare with a vXi1 result and the input is 128/256
2349623524
// bits we can't use a masked compare unless we have VLX. If we use a wider
@@ -23701,12 +23729,12 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2370123729

2370223730
// Break 256-bit integer vector compare into smaller ones.
2370323731
if (VT.is256BitVector() && !Subtarget.hasInt256())
23704-
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23732+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
2370523733

2370623734
// Break 512-bit integer vector compare into smaller ones.
2370723735
// TODO: Try harder to use VPCMPx + VPMOV2x?
2370823736
if (VT.is512BitVector())
23709-
return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23737+
return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
2371023738

2371123739
// If we have a limit constant, try to form PCMPGT (signed cmp) to avoid
2371223740
// not-of-PCMPEQ:

0 commit comments

Comments
 (0)