@@ -1711,6 +1711,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17111711 setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
17121712 setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
17131713 }
1714+ setOperationAction(ISD::SETCC, MVT::v8f16, Custom);
1715+ setOperationAction(ISD::SETCC, MVT::v16f16, Custom);
17141716 }
17151717
17161718 // This block controls legalization of the mask vector sizes that are
@@ -2046,6 +2048,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20462048 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom);
20472049 for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
20482050 setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
2051+ setOperationAction(ISD::SETCC, MVT::v32f16, Custom);
20492052
20502053 for (auto VT : { MVT::v16i32, MVT::v8i64, MVT::v16f32, MVT::v8f64 }) {
20512054 setOperationAction(ISD::MLOAD, VT, Legal);
@@ -2401,6 +2404,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24012404 setOperationPromotedToType(Opc, MVT::v8bf16, MVT::v8f32);
24022405 setOperationPromotedToType(Opc, MVT::v16bf16, MVT::v16f32);
24032406 }
2407+ setOperationAction(ISD::SETCC, MVT::v8bf16, Custom);
2408+ setOperationAction(ISD::SETCC, MVT::v16bf16, Custom);
24042409 setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
24052410 addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
24062411 }
@@ -2411,6 +2416,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24112416 setF16Action(MVT::v32bf16, Expand);
24122417 for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
24132418 setOperationPromotedToType(Opc, MVT::v32bf16, MVT::v32f32);
2419+ setOperationAction(ISD::SETCC, MVT::v32bf16, Custom);
24142420 setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
24152421 setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
24162422 setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
@@ -23396,14 +23402,12 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
2339623402 return SSECC;
2339723403}
2339823404
23399- /// Break a VSETCC 256-bit integer VSETCC into two new 128 ones and then
23405+ /// Break a VSETCC 256/512 -bit vector into two new 128/256 ones and then
2340023406/// concatenate the result back.
23401- static SDValue splitIntVSETCC(EVT VT, SDValue LHS, SDValue RHS,
23402- ISD::CondCode Cond, SelectionDAG &DAG,
23403- const SDLoc &dl) {
23404- assert(VT.isInteger() && VT == LHS.getValueType() &&
23405- VT == RHS.getValueType() && "Unsupported VTs!");
23406-
23407+ static SDValue splitVSETCC(EVT VT, SDValue LHS, SDValue RHS, ISD::CondCode Cond,
23408+ SelectionDAG &DAG, const SDLoc &dl) {
23409+ assert(VT.isInteger() && LHS.getValueType() == RHS.getValueType() &&
23410+ "Unsupported VTs!");
2340723411 SDValue CC = DAG.getCondCode(Cond);
2340823412
2340923413 // Extract the LHS Lo/Hi vectors
@@ -23544,18 +23548,43 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2354423548 SDValue CC = Op.getOperand(IsStrict ? 3 : 2);
2354523549 MVT VT = Op->getSimpleValueType(0);
2354623550 ISD::CondCode Cond = cast<CondCodeSDNode>(CC)->get();
23547- bool isFP = Op1 .getSimpleValueType().isFloatingPoint ();
23551+ MVT OpVT = Op0 .getSimpleValueType();
2354823552 SDLoc dl(Op);
2354923553
23550- if (isFP ) {
23551- MVT EltVT = Op0.getSimpleValueType() .getVectorElementType();
23554+ if (OpVT.isFloatingPoint() ) {
23555+ MVT EltVT = OpVT .getVectorElementType();
2355223556 assert(EltVT == MVT::bf16 || EltVT == MVT::f16 || EltVT == MVT::f32 ||
2355323557 EltVT == MVT::f64);
23554- if (isSoftF16(EltVT, Subtarget))
23555- return SDValue();
2355623558
23557- bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2355823559 SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
23560+ if (isSoftF16(EltVT, Subtarget)) {
23561+ // Break 256-bit FP vector compare into smaller ones.
23562+ if (OpVT.is256BitVector() && !Subtarget.useAVX512Regs())
23563+ return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23564+
23565+ // Break 512-bit FP vector compare into smaller ones.
23566+ if (OpVT.is512BitVector())
23567+ return splitVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23568+
23569+ MVT NVT = OpVT.changeVectorElementType(MVT::f32);
23570+ if (IsStrict) {
23571+ Op0 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23572+ {Chain, Op0});
23573+ Op1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
23574+ {Chain, Op1});
23575+ return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
23576+ {Chain, Op0, Op1, CC});
23577+ }
23578+ MVT DVT = VT.getVectorElementType() == MVT::i16
23579+ ? VT.changeVectorElementType(MVT::i32)
23580+ : VT;
23581+ SDValue Cmp = DAG.getNode(Op.getOpcode(), dl, DVT,
23582+ DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op0),
23583+ DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op1), CC);
23584+ return DVT == VT ? Cmp : DAG.getNode(ISD::TRUNCATE, dl, VT, Cmp);
23585+ }
23586+
23587+ bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
2355923588
2356023589 // If we have a strict compare with a vXi1 result and the input is 128/256
2356123590 // bits we can't use a masked compare unless we have VLX. If we use a wider
@@ -23766,12 +23795,12 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2376623795
2376723796 // Break 256-bit integer vector compare into smaller ones.
2376823797 if (VT.is256BitVector() && !Subtarget.hasInt256())
23769- return splitIntVSETCC (VT, Op0, Op1, Cond, DAG, dl);
23798+ return splitVSETCC (VT, Op0, Op1, Cond, DAG, dl);
2377023799
2377123800 // Break 512-bit integer vector compare into smaller ones.
2377223801 // TODO: Try harder to use VPCMPx + VPMOV2x?
2377323802 if (VT.is512BitVector())
23774- return splitIntVSETCC (VT, Op0, Op1, Cond, DAG, dl);
23803+ return splitVSETCC (VT, Op0, Op1, Cond, DAG, dl);
2377523804
2377623805 // If we have a limit constant, try to form PCMPGT (signed cmp) to avoid
2377723806 // not-of-PCMPEQ:
0 commit comments