@@ -711,6 +711,13 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
711711 setOperationAction (ISD::BITCAST, MVT::f32 , Custom);
712712 }
713713
714+ // Expand FP16 <=> FP32 conversions to libcalls and handle FP16 loads and
715+ // stores in GPRs.
716+ setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Expand);
717+ setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Expand);
718+ setLoadExtAction (ISD::EXTLOAD, MVT::f32 , MVT::f16 , Expand);
719+ setTruncStoreAction (MVT::f32 , MVT::f16 , Expand);
720+
714721 // VASTART and VACOPY need to deal with the SystemZ-specific varargs
715722 // structure, but VAEND is a no-op.
716723 setOperationAction (ISD::VASTART, MVT::Other, Custom);
@@ -784,6 +791,20 @@ bool SystemZTargetLowering::useSoftFloat() const {
784791 return Subtarget.hasSoftFloat ();
785792}
786793
794+ MVT SystemZTargetLowering::getRegisterTypeForCallingConv (
795+ LLVMContext &Context, CallingConv::ID CC,
796+ EVT VT) const {
797+ // 128-bit single-element vector types are passed like other vectors,
798+ // not like their element type.
799+ if (VT.isVector () && VT.getSizeInBits () == 128 &&
800+ VT.getVectorNumElements () == 1 )
801+ return MVT::v16i8;
802+ // Keep f16 so that they can be recognized and handled.
803+ if (VT == MVT::f16 )
804+ return MVT::f16 ;
805+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
806+ }
807+
787808EVT SystemZTargetLowering::getSetCCResultType (const DataLayout &DL,
788809 LLVMContext &, EVT VT) const {
789810 if (!VT.isVector ())
@@ -1597,6 +1618,15 @@ bool SystemZTargetLowering::splitValueIntoRegisterParts(
15971618 return true ;
15981619 }
15991620
1621+ // Convert f16 to f32 (Out-arg).
1622+ if (PartVT == MVT::f16 ) {
1623+ assert (NumParts == 1 && " " );
1624+ SDValue I16Val = DAG.getBitcast (MVT::i16 , Val);
1625+ SDValue I32Val = DAG.getAnyExtOrTrunc (I16Val, DL, MVT::i32 );
1626+ Parts[0 ] = DAG.getBitcast (MVT::f32 , I32Val);
1627+ return true ;
1628+ }
1629+
16001630 return false ;
16011631}
16021632
@@ -1612,6 +1642,18 @@ SDValue SystemZTargetLowering::joinRegisterPartsIntoValue(
16121642 return SDValue ();
16131643}
16141644
1645+ // F32Val holds a f16 value in f32, return it as an f16 (In-arg). The
1646+ // CopyFromReg was made into an f32 as required as FP32 registers are used
1647+ // for arguments, now convert it to f16.
1648+ static SDValue convertF32ToF16 (SDValue F32Val, SelectionDAG &DAG,
1649+ const SDLoc &DL) {
1650+ assert (F32Val->getOpcode () == ISD::CopyFromReg &&
1651+ " Only expecting to handle f16 with CopyFromReg here." );
1652+ SDValue I32Val = DAG.getBitcast (MVT::i32 , F32Val);
1653+ SDValue I16Val = DAG.getAnyExtOrTrunc (I32Val, DL, MVT::i16 );
1654+ return DAG.getBitcast (MVT::f16 , I16Val);
1655+ }
1656+
16151657SDValue SystemZTargetLowering::LowerFormalArguments (
16161658 SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
16171659 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -1651,6 +1693,7 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16511693 NumFixedGPRs += 1 ;
16521694 RC = &SystemZ::GR64BitRegClass;
16531695 break ;
1696+ case MVT::f16 :
16541697 case MVT::f32 :
16551698 NumFixedFPRs += 1 ;
16561699 RC = &SystemZ::FP32BitRegClass;
@@ -1675,7 +1718,11 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16751718
16761719 Register VReg = MRI.createVirtualRegister (RC);
16771720 MRI.addLiveIn (VA.getLocReg (), VReg);
1678- ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, LocVT);
1721+ // Special handling is needed for f16.
1722+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
1723+ ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, ArgVT);
1724+ if (VA.getLocVT () == MVT::f16 )
1725+ ArgValue = convertF32ToF16 (ArgValue, DAG, DL);
16791726 } else {
16801727 assert (VA.isMemLoc () && " Argument not register or memory" );
16811728
@@ -1695,9 +1742,12 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16951742 // from this parameter. Unpromoted ints and floats are
16961743 // passed as right-justified 8-byte values.
16971744 SDValue FIN = DAG.getFrameIndex (FI, PtrVT);
1698- if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32 )
1745+ if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32 ||
1746+ VA.getLocVT () == MVT::f16 ) {
1747+ unsigned SlotOffs = VA.getLocVT () == MVT::f16 ? 6 : 4 ;
16991748 FIN = DAG.getNode (ISD::ADD, DL, PtrVT, FIN,
1700- DAG.getIntPtrConstant (4 , DL));
1749+ DAG.getIntPtrConstant (SlotOffs, DL));
1750+ }
17011751 ArgValue = DAG.getLoad (LocVT, DL, Chain, FIN,
17021752 MachinePointerInfo::getFixedStack (MF, FI));
17031753 }
@@ -2120,10 +2170,14 @@ SystemZTargetLowering::LowerCall(CallLoweringInfo &CLI,
21202170 // Copy all of the result registers out of their specified physreg.
21212171 for (CCValAssign &VA : RetLocs) {
21222172 // Copy the value out, gluing the copy to the end of the call sequence.
2173+ // Special handling is needed for f16.
2174+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
21232175 SDValue RetValue = DAG.getCopyFromReg (Chain, DL, VA.getLocReg (),
2124- VA. getLocVT () , Glue);
2176+ ArgVT , Glue);
21252177 Chain = RetValue.getValue (1 );
21262178 Glue = RetValue.getValue (2 );
2179+ if (VA.getLocVT () == MVT::f16 )
2180+ RetValue = convertF32ToF16 (RetValue, DAG, DL);
21272181
21282182 // Convert the value of the return register into the value that's
21292183 // being returned.
0 commit comments