@@ -181,8 +181,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
181181 setOperationAction (ISD::FSINCOS, MVT::f32 , Expand);
182182 setOperationAction (ISD::FPOW, MVT::f32 , Expand);
183183 setOperationAction (ISD::FREM, MVT::f32 , Expand);
184- setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Expand );
185- setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Expand );
184+ setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Custom );
185+ setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Custom );
186186
187187 if (Subtarget.is64Bit ())
188188 setOperationAction (ISD::FRINT, MVT::f32 , Legal);
@@ -219,7 +219,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
219219 setOperationAction (ISD::FPOW, MVT::f64 , Expand);
220220 setOperationAction (ISD::FREM, MVT::f64 , Expand);
221221 setOperationAction (ISD::FP16_TO_FP, MVT::f64 , Expand);
222- setOperationAction (ISD::FP_TO_FP16, MVT::f64 , Expand );
222+ setOperationAction (ISD::FP_TO_FP16, MVT::f64 , Custom );
223223
224224 if (Subtarget.is64Bit ())
225225 setOperationAction (ISD::FRINT, MVT::f64 , Legal);
@@ -427,6 +427,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
427427 return lowerBUILD_VECTOR (Op, DAG);
428428 case ISD::VECTOR_SHUFFLE:
429429 return lowerVECTOR_SHUFFLE (Op, DAG);
430+ case ISD::FP_TO_FP16:
431+ return lowerFP_TO_FP16 (Op, DAG);
432+ case ISD::FP16_TO_FP:
433+ return lowerFP16_TO_FP (Op, DAG);
430434 }
431435 return SDValue ();
432436}
@@ -1354,6 +1358,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
13541358 return SDValue ();
13551359}
13561360
1361+ SDValue LoongArchTargetLowering::lowerFP_TO_FP16 (SDValue Op,
1362+ SelectionDAG &DAG) const {
1363+ // Custom lower to ensure the libcall return is passed in an FPR on hard
1364+ // float ABIs.
1365+ SDLoc DL (Op);
1366+ MakeLibCallOptions CallOptions;
1367+ SDValue Op0 = Op.getOperand (0 );
1368+ SDValue Chain = SDValue ();
1369+ RTLIB::Libcall LC = RTLIB::getFPROUND (Op0.getValueType (), MVT::f16 );
1370+ SDValue Res;
1371+ std::tie (Res, Chain) =
1372+ makeLibCall (DAG, LC, MVT::f32 , Op0, CallOptions, DL, Chain);
1373+ if (Subtarget.is64Bit ())
1374+ return DAG.getNode (LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64 , Res);
1375+ return DAG.getBitcast (MVT::i32 , Res);
1376+ }
1377+
1378+ SDValue LoongArchTargetLowering::lowerFP16_TO_FP (SDValue Op,
1379+ SelectionDAG &DAG) const {
1380+ // Custom lower to ensure the libcall argument is passed in an FPR on hard
1381+ // float ABIs.
1382+ SDLoc DL (Op);
1383+ MakeLibCallOptions CallOptions;
1384+ SDValue Op0 = Op.getOperand (0 );
1385+ SDValue Chain = SDValue ();
1386+ SDValue Arg = Subtarget.is64Bit () ? DAG.getNode (LoongArchISD::MOVGR2FR_W_LA64,
1387+ DL, MVT::f32 , Op0)
1388+ : DAG.getBitcast (MVT::f32 , Op0);
1389+ SDValue Res;
1390+ std::tie (Res, Chain) = makeLibCall (DAG, RTLIB::FPEXT_F16_F32, MVT::f32 , Arg,
1391+ CallOptions, DL, Chain);
1392+ return Res;
1393+ }
1394+
13571395static bool isConstantOrUndef (const SDValue Op) {
13581396 if (Op->isUndef ())
13591397 return true ;
@@ -1656,16 +1694,19 @@ SDValue LoongArchTargetLowering::lowerFP_TO_SINT(SDValue Op,
16561694 SelectionDAG &DAG) const {
16571695
16581696 SDLoc DL (Op);
1697+ SDValue Op0 = Op.getOperand (0 );
1698+
1699+ if (Op0.getValueType () == MVT::f16 )
1700+ Op0 = DAG.getNode (ISD::FP_EXTEND, DL, MVT::f32 , Op0);
16591701
16601702 if (Op.getValueSizeInBits () > 32 && Subtarget.hasBasicF () &&
16611703 !Subtarget.hasBasicD ()) {
1662- SDValue Dst =
1663- DAG.getNode (LoongArchISD::FTINT, DL, MVT::f32 , Op.getOperand (0 ));
1704+ SDValue Dst = DAG.getNode (LoongArchISD::FTINT, DL, MVT::f32 , Op0);
16641705 return DAG.getNode (LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64 , Dst);
16651706 }
16661707
16671708 EVT FPTy = EVT::getFloatingPointVT (Op.getValueSizeInBits ());
1668- SDValue Trunc = DAG.getNode (LoongArchISD::FTINT, DL, FPTy, Op. getOperand ( 0 ) );
1709+ SDValue Trunc = DAG.getNode (LoongArchISD::FTINT, DL, FPTy, Op0 );
16691710 return DAG.getNode (ISD::BITCAST, DL, Op.getValueType (), Trunc);
16701711}
16711712
@@ -2848,6 +2889,10 @@ void LoongArchTargetLowering::ReplaceNodeResults(
28482889 EVT FVT = EVT::getFloatingPointVT (N->getValueSizeInBits (0 ));
28492890 if (getTypeAction (*DAG.getContext (), Src.getValueType ()) !=
28502891 TargetLowering::TypeSoftenFloat) {
2892+ if (!isTypeLegal (Src.getValueType ()))
2893+ return ;
2894+ if (Src.getValueType () == MVT::f16 )
2895+ Src = DAG.getNode (ISD::FP_EXTEND, DL, MVT::f32 , Src);
28512896 SDValue Dst = DAG.getNode (LoongArchISD::FTINT, DL, FVT, Src);
28522897 Results.push_back (DAG.getNode (ISD::BITCAST, DL, VT, Dst));
28532898 return ;
@@ -4229,6 +4274,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
42294274 return SDValue ();
42304275}
42314276
4277+ static SDValue performMOVGR2FR_WCombine (SDNode *N, SelectionDAG &DAG,
4278+ TargetLowering::DAGCombinerInfo &DCI,
4279+ const LoongArchSubtarget &Subtarget) {
4280+ // If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 the the
4281+ // conversion is unnecessary and can be replaced with the
4282+ // MOVFR2GR_S_LA64 operand.
4283+ SDValue Op0 = N->getOperand (0 );
4284+ if (Op0.getOpcode () == LoongArchISD::MOVFR2GR_S_LA64)
4285+ return Op0.getOperand (0 );
4286+ return SDValue ();
4287+ }
4288+
4289+ static SDValue performMOVFR2GR_SCombine (SDNode *N, SelectionDAG &DAG,
4290+ TargetLowering::DAGCombinerInfo &DCI,
4291+ const LoongArchSubtarget &Subtarget) {
4292+ // If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
4293+ // conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
4294+ // operand.
4295+ SDValue Op0 = N->getOperand (0 );
4296+ MVT VT = N->getSimpleValueType (0 );
4297+ if (Op0->getOpcode () == LoongArchISD::MOVGR2FR_W_LA64) {
4298+ assert (Op0.getOperand (0 ).getValueType () == VT && " Unexpected value type!" );
4299+ return Op0.getOperand (0 );
4300+ }
4301+ return SDValue ();
4302+ }
4303+
42324304SDValue LoongArchTargetLowering::PerformDAGCombine (SDNode *N,
42334305 DAGCombinerInfo &DCI) const {
42344306 SelectionDAG &DAG = DCI.DAG ;
@@ -4247,6 +4319,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
42474319 return performBITREV_WCombine (N, DAG, DCI, Subtarget);
42484320 case ISD::INTRINSIC_WO_CHAIN:
42494321 return performINTRINSIC_WO_CHAINCombine (N, DAG, DCI, Subtarget);
4322+ case LoongArchISD::MOVGR2FR_W_LA64:
4323+ return performMOVGR2FR_WCombine (N, DAG, DCI, Subtarget);
4324+ case LoongArchISD::MOVFR2GR_S_LA64:
4325+ return performMOVFR2GR_SCombine (N, DAG, DCI, Subtarget);
42504326 }
42514327 return SDValue ();
42524328}
@@ -6260,3 +6336,61 @@ bool LoongArchTargetLowering::shouldAlignPointerArgs(CallInst *CI,
62606336
62616337 return true ;
62626338}
6339+
6340+ bool LoongArchTargetLowering::splitValueIntoRegisterParts (
6341+ SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
6342+ unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
6343+ bool IsABIRegCopy = CC.has_value ();
6344+ EVT ValueVT = Val.getValueType ();
6345+
6346+ if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32 ) {
6347+ // Cast the f16 to i16, extend to i32, pad with ones to make a float
6348+ // nan, and cast to f32.
6349+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::i16 , Val);
6350+ Val = DAG.getNode (ISD::ANY_EXTEND, DL, MVT::i32 , Val);
6351+ Val = DAG.getNode (ISD::OR, DL, MVT::i32 , Val,
6352+ DAG.getConstant (0xFFFF0000 , DL, MVT::i32 ));
6353+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::f32 , Val);
6354+ Parts[0 ] = Val;
6355+ return true ;
6356+ }
6357+
6358+ return false ;
6359+ }
6360+
6361+ SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue (
6362+ SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
6363+ MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
6364+ bool IsABIRegCopy = CC.has_value ();
6365+
6366+ if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32 ) {
6367+ SDValue Val = Parts[0 ];
6368+
6369+ // Cast the f32 to i32, truncate to i16, and cast back to f16.
6370+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::i32 , Val);
6371+ Val = DAG.getNode (ISD::TRUNCATE, DL, MVT::i16 , Val);
6372+ Val = DAG.getNode (ISD::BITCAST, DL, ValueVT, Val);
6373+ return Val;
6374+ }
6375+
6376+ return SDValue ();
6377+ }
6378+
6379+ MVT LoongArchTargetLowering::getRegisterTypeForCallingConv (LLVMContext &Context,
6380+ CallingConv::ID CC,
6381+ EVT VT) const {
6382+ // Use f32 to pass f16.
6383+ if (VT == MVT::f16 && Subtarget.hasBasicF ())
6384+ return MVT::f32 ;
6385+
6386+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
6387+ }
6388+
6389+ unsigned LoongArchTargetLowering::getNumRegistersForCallingConv (
6390+ LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
6391+ // Use f32 to pass f16.
6392+ if (VT == MVT::f16 && Subtarget.hasBasicF ())
6393+ return 1 ;
6394+
6395+ return TargetLowering::getNumRegistersForCallingConv (Context, CC, VT);
6396+ }
0 commit comments