@@ -735,6 +735,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
735735 setOperationAction(ISD::FCANONICALIZE, MVT::f16, Custom);
736736 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
737737 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
738+ setOperationAction(ISD::LRINT, MVT::f16, Expand);
739+ setOperationAction(ISD::LLRINT, MVT::f16, Expand);
738740
739741 setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
740742 setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
@@ -2312,6 +2314,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23122314 setOperationAction(ISD::FMINIMUMNUM, MVT::f16, Custom);
23132315 setOperationAction(ISD::FP_EXTEND, MVT::f32, Legal);
23142316 setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
2317+ setOperationAction(ISD::LRINT, MVT::f16, Legal);
2318+ setOperationAction(ISD::LLRINT, MVT::f16, Legal);
23152319
23162320 setCondCodeAction(ISD::SETOEQ, MVT::f16, Expand);
23172321 setCondCodeAction(ISD::SETUNE, MVT::f16, Expand);
@@ -2359,6 +2363,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23592363 setOperationAction(ISD::FMAXIMUM, MVT::v32f16, Custom);
23602364 setOperationAction(ISD::FMINIMUMNUM, MVT::v32f16, Custom);
23612365 setOperationAction(ISD::FMAXIMUMNUM, MVT::v32f16, Custom);
2366+ setOperationAction(ISD::LRINT, MVT::v32f16, Legal);
2367+ setOperationAction(ISD::LLRINT, MVT::v8f16, Legal);
23622368 }
23632369
23642370 if (Subtarget.hasVLX()) {
@@ -2413,6 +2419,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24132419 setOperationAction(ISD::FMAXIMUM, MVT::v16f16, Custom);
24142420 setOperationAction(ISD::FMINIMUMNUM, MVT::v16f16, Custom);
24152421 setOperationAction(ISD::FMAXIMUMNUM, MVT::v16f16, Custom);
2422+ setOperationAction(ISD::LRINT, MVT::v8f16, Legal);
2423+ setOperationAction(ISD::LRINT, MVT::v16f16, Legal);
24162424 }
24172425 }
24182426
@@ -34055,8 +34063,15 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
3405534063 case ISD::LRINT:
3405634064 if (N->getValueType(0) == MVT::v2i32) {
3405734065 SDValue Src = N->getOperand(0);
34058- if (Src.getValueType() == MVT::v2f64)
34059- Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
34066+ if (Subtarget.hasFP16() && Src.getValueType() == MVT::v2f16) {
34067+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f16, Src,
34068+ DAG.getUNDEF(MVT::v2f16));
34069+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Src,
34070+ DAG.getUNDEF(MVT::v4f16));
34071+ } else if (Src.getValueType() != MVT::v2f64) {
34072+ return;
34073+ }
34074+ Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
3406034075 return;
3406134076 }
3406234077 [[fallthrough]];
@@ -53640,13 +53655,35 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
5364053655 EVT SrcVT = Src.getValueType();
5364153656 SDLoc DL(N);
5364253657
53643- if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
53644- SrcVT != MVT::v2f32)
53658+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53659+
53660+ // Let legalize expand this if it isn't a legal type yet.
53661+ if (!TLI.isTypeLegal(VT))
53662+ return SDValue();
53663+
53664+ if ((SrcVT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) ||
53665+ (SrcVT.getScalarType() == MVT::f32 && !Subtarget.hasDQI()))
5364553666 return SDValue();
5364653667
53647- return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
53648- DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
53649- DAG.getUNDEF(SrcVT)));
53668+ if (SrcVT == MVT::v2f16) {
53669+ SrcVT = MVT::v4f16;
53670+ Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53671+ DAG.getUNDEF(MVT::v2f16));
53672+ }
53673+
53674+ if (SrcVT == MVT::v4f16) {
53675+ SrcVT = MVT::v8f16;
53676+ Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53677+ DAG.getUNDEF(MVT::v4f16));
53678+ } else if (SrcVT == MVT::v2f32) {
53679+ SrcVT = MVT::v4f32;
53680+ Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53681+ DAG.getUNDEF(MVT::v2f32));
53682+ } else {
53683+ return SDValue();
53684+ }
53685+
53686+ return DAG.getNode(X86ISD::CVTP2SI, DL, VT, Src);
5365053687}
5365153688
5365253689/// Attempt to pre-truncate inputs to arithmetic ops if it will simplify
0 commit comments