@@ -1711,6 +1711,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17111711 setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
17121712 setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
17131713 }
1714+ setOperationAction(ISD::SETCC, MVT::v8f16, Custom);
1715+ setOperationAction(ISD::SETCC, MVT::v16f16, Custom);
17141716 }
17151717
17161718 // This block controls legalization of the mask vector sizes that are
@@ -2046,6 +2048,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20462048 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom);
20472049 for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
20482050 setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
2051+ setOperationAction(ISD::SETCC, MVT::v32f16, Custom);
20492052
20502053 for (auto VT : { MVT::v16i32, MVT::v8i64, MVT::v16f32, MVT::v8f64 }) {
20512054 setOperationAction(ISD::MLOAD, VT, Legal);
@@ -2401,6 +2404,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24012404 setOperationPromotedToType(Opc, MVT::v8bf16, MVT::v8f32);
24022405 setOperationPromotedToType(Opc, MVT::v16bf16, MVT::v16f32);
24032406 }
2407+ setOperationAction(ISD::SETCC, MVT::v8bf16, Custom);
2408+ setOperationAction(ISD::SETCC, MVT::v16bf16, Custom);
24042409 setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
24052410 addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
24062411 }
@@ -2411,6 +2416,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24112416 setF16Action(MVT::v32bf16, Expand);
24122417 for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
24132418 setOperationPromotedToType(Opc, MVT::v32bf16, MVT::v32f32);
2419+ setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
24142420 setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
24152421 setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
24162422 setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
@@ -23333,12 +23339,8 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
2333323339
2333423340/// Break a VSETCC 256-bit integer VSETCC into two new 128 ones and then
2333523341/// concatenate the result back.
23336- static SDValue splitIntVSETCC(EVT VT, SDValue LHS, SDValue RHS,
23337- ISD::CondCode Cond, SelectionDAG &DAG,
23338- const SDLoc &dl) {
23339- assert(VT.isInteger() && VT == LHS.getValueType() &&
23340- VT == RHS.getValueType() && "Unsupported VTs!");
23341-
23342+ static SDValue splitVSETCC(EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond,
23343+ SelectionDAG &DAG, const SDLoc &dl) {
2334223344 SDValue CC = DAG.getCondCode(Cond);
2334323345
2334423346 // Extract the LHS Lo/Hi vectors
@@ -23483,14 +23485,40 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2348323485 SDLoc dl(Op);
2348423486
2348523487 if (isFP) {
23486- MVT EltVT = Op0.getSimpleValueType().getVectorElementType();
23488+ MVT SVT = Op0.getSimpleValueType();
23489+ MVT EltVT = SVT.getVectorElementType();
2348723490 assert(EltVT == MVT::bf16 || EltVT == MVT::f16 || EltVT == MVT::f32 ||
2348823491 EltVT == MVT::f64);
23489- if (isSoftF16(EltVT, Subtarget))
23490- return SDValue();
2349123492
23492- bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2349323493 SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
23494+ if (isSoftF16(EltVT, Subtarget)) {
23495+ // Break 256-bit FP vector compare into smaller ones.
23496+ if (SVT.is256BitVector() && !Subtarget.useAVX512Regs())
23497+ return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23498+
23499+ // Break 512-bit FP vector compare into smaller ones.
23500+ if (SVT.is512BitVector())
23501+ return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23502+
23503+ MVT NVT = SVT.changeVectorElementType(MVT::f32);
23504+ if (IsStrict) {
23505+ Op0 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23506+ {Chain, Op0});
23507+ Op1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23508+ {Chain, Op1});
23509+ return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
23510+ {Chain, Op0, Op1, CC});
23511+ }
23512+ MVT DVT = VT.getVectorElementType() == MVT::i16
23513+ ? VT.changeVectorElementType(MVT::i32)
23514+ : VT;
23515+ SDValue Cmp = DAG.getNode(Op.getOpcode(), dl, DVT,
23516+ DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op0),
23517+ DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op1), CC);
23518+ return DVT == VT ? Cmp : DAG.getNode(ISD::TRUNCATE, dl, VT, Cmp);
23519+ }
23520+
23521+ bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2349423522
2349523523 // If we have a strict compare with a vXi1 result and the input is 128/256
2349623524 // bits we can't use a masked compare unless we have VLX. If we use a wider
@@ -23701,12 +23729,12 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2370123729
2370223730 // Break 256-bit integer vector compare into smaller ones.
2370323731 if (VT.is256BitVector() && !Subtarget.hasInt256())
23704- return splitIntVSETCC (VT, Op0, Op1, Cond, DAG, dl);
23732+ return splitVSETCC (VT, Op0, Op1, Cond, DAG, dl);
2370523733
2370623734 // Break 512-bit integer vector compare into smaller ones.
2370723735 // TODO: Try harder to use VPCMPx + VPMOV2x?
2370823736 if (VT.is512BitVector())
23709- return splitIntVSETCC (VT, Op0, Op1, Cond, DAG, dl);
23737+ return splitVSETCC (VT, Op0, Op1, Cond, DAG, dl);
2371023738
2371123739 // If we have a limit constant, try to form PCMPGT (signed cmp) to avoid
2371223740 // not-of-PCMPEQ:
0 commit comments