@@ -711,6 +711,13 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
711711 setOperationAction (ISD::BITCAST, MVT::f32 , Custom);
712712 }
713713
714+ // Expand FP16 <=> FP32 conversions to libcalls and handle FP16 loads and
715+ // stores in GPRs.
716+ setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Expand);
717+ setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Expand);
718+ setLoadExtAction (ISD::EXTLOAD, MVT::f32 , MVT::f16 , Expand);
719+ setTruncStoreAction (MVT::f32 , MVT::f16 , Expand);
720+
714721 // VASTART and VACOPY need to deal with the SystemZ-specific varargs
715722 // structure, but VAEND is a no-op.
716723 setOperationAction (ISD::VASTART, MVT::Other, Custom);
@@ -784,6 +791,20 @@ bool SystemZTargetLowering::useSoftFloat() const {
784791 return Subtarget.hasSoftFloat ();
785792}
786793
794+ MVT SystemZTargetLowering::getRegisterTypeForCallingConv (
795+ LLVMContext &Context, CallingConv::ID CC,
796+ EVT VT) const {
797+ // 128-bit single-element vector types are passed like other vectors,
798+ // not like their element type.
799+ if (VT.isVector () && VT.getSizeInBits () == 128 &&
800+ VT.getVectorNumElements () == 1 )
801+ return MVT::v16i8;
802+ // Keep f16 so that they can be recognized and handled.
803+ if (VT == MVT::f16 )
804+ return MVT::f16 ;
805+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
806+ }
807+
787808EVT SystemZTargetLowering::getSetCCResultType (const DataLayout &DL,
788809 LLVMContext &, EVT VT) const {
789810 if (!VT.isVector ())
@@ -1602,6 +1623,15 @@ bool SystemZTargetLowering::splitValueIntoRegisterParts(
16021623 return true ;
16031624 }
16041625
1626+ // Convert f16 to f32 (Out-arg).
1627+ if (PartVT == MVT::f16 ) {
1628+ assert (NumParts == 1 && " " );
1629+ SDValue I16Val = DAG.getBitcast (MVT::i16 , Val);
1630+ SDValue I32Val = DAG.getAnyExtOrTrunc (I16Val, DL, MVT::i32 );
1631+ Parts[0 ] = DAG.getBitcast (MVT::f32 , I32Val);
1632+ return true ;
1633+ }
1634+
16051635 return false ;
16061636}
16071637
@@ -1617,6 +1647,18 @@ SDValue SystemZTargetLowering::joinRegisterPartsIntoValue(
16171647 return SDValue ();
16181648}
16191649
1650+ // F32Val holds a f16 value in f32, return it as an f16 (In-arg). The
1651+ // CopyFromReg was made into an f32 as required as FP32 registers are used
1652+ // for arguments, now convert it to f16.
1653+ static SDValue convertF32ToF16 (SDValue F32Val, SelectionDAG &DAG,
1654+ const SDLoc &DL) {
1655+ assert (F32Val->getOpcode () == ISD::CopyFromReg &&
1656+ " Only expecting to handle f16 with CopyFromReg here." );
1657+ SDValue I32Val = DAG.getBitcast (MVT::i32 , F32Val);
1658+ SDValue I16Val = DAG.getAnyExtOrTrunc (I32Val, DL, MVT::i16 );
1659+ return DAG.getBitcast (MVT::f16 , I16Val);
1660+ }
1661+
16201662SDValue SystemZTargetLowering::LowerFormalArguments (
16211663 SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
16221664 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -1656,6 +1698,7 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16561698 NumFixedGPRs += 1 ;
16571699 RC = &SystemZ::GR64BitRegClass;
16581700 break ;
1701+ case MVT::f16 :
16591702 case MVT::f32 :
16601703 NumFixedFPRs += 1 ;
16611704 RC = &SystemZ::FP32BitRegClass;
@@ -1680,7 +1723,11 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16801723
16811724 Register VReg = MRI.createVirtualRegister (RC);
16821725 MRI.addLiveIn (VA.getLocReg (), VReg);
1683- ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, LocVT);
1726+ // Special handling is needed for f16.
1727+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
1728+ ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, ArgVT);
1729+ if (VA.getLocVT () == MVT::f16 )
1730+ ArgValue = convertF32ToF16 (ArgValue, DAG, DL);
16841731 } else {
16851732 assert (VA.isMemLoc () && " Argument not register or memory" );
16861733
@@ -1700,9 +1747,12 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
17001747 // from this parameter. Unpromoted ints and floats are
17011748 // passed as right-justified 8-byte values.
17021749 SDValue FIN = DAG.getFrameIndex (FI, PtrVT);
1703- if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32 )
1750+ if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32 ||
1751+ VA.getLocVT () == MVT::f16 ) {
1752+ unsigned SlotOffs = VA.getLocVT () == MVT::f16 ? 6 : 4 ;
17041753 FIN = DAG.getNode (ISD::ADD, DL, PtrVT, FIN,
1705- DAG.getIntPtrConstant (4 , DL));
1754+ DAG.getIntPtrConstant (SlotOffs, DL));
1755+ }
17061756 ArgValue = DAG.getLoad (LocVT, DL, Chain, FIN,
17071757 MachinePointerInfo::getFixedStack (MF, FI));
17081758 }
@@ -2121,10 +2171,14 @@ SystemZTargetLowering::LowerCall(CallLoweringInfo &CLI,
21212171 // Copy all of the result registers out of their specified physreg.
21222172 for (CCValAssign &VA : RetLocs) {
21232173 // Copy the value out, gluing the copy to the end of the call sequence.
2174+ // Special handling is needed for f16.
2175+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
21242176 SDValue RetValue = DAG.getCopyFromReg (Chain, DL, VA.getLocReg (),
2125- VA. getLocVT () , Glue);
2177+ ArgVT , Glue);
21262178 Chain = RetValue.getValue (1 );
21272179 Glue = RetValue.getValue (2 );
2180+ if (VA.getLocVT () == MVT::f16 )
2181+ RetValue = convertF32ToF16 (RetValue, DAG, DL);
21282182
21292183 // Convert the value of the return register into the value that's
21302184 // being returned.
0 commit comments