@@ -6824,6 +6824,17 @@ SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
68246824
68256825 SDLoc DL(Op);
68266826
6827+ if (Subtarget->useRealTrue16Insts()) {
6828+ if (getTargetMachine().Options.UnsafeFPMath) {
6829+ SDValue Flags = Op.getOperand(1);
6830+ SDValue Src32 = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src, Flags);
6831+ return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Src32, Flags);
6832+ }
6833+
6834+ SDValue FpToFp16 = LowerF64ToF16(Src, MVT::i16, DL, DAG);
6835+ return DAG.getNode(ISD::BITCAST, DL, MVT::f16, FpToFp16);
6836+ }
6837+
68276838 SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
68286839 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
68296840 return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
@@ -17002,6 +17013,8 @@ SITargetLowering::getRegClassFor(MVT VT, bool isDivergent) const {
1700217013 if (RC == &AMDGPU::VReg_1RegClass && !isDivergent)
1700317014 return Subtarget->isWave64() ? &AMDGPU::SReg_64RegClass
1700417015 : &AMDGPU::SReg_32RegClass;
17016+ if (VT == MVT::f16 && TRI->isVGPRClass(RC))
17017+ return RC;
1700517018 if (!TRI->isSGPRClass(RC) && !isDivergent)
1700617019 return TRI->getEquivalentSGPRClass(RC);
1700717020 if (TRI->isSGPRClass(RC) && isDivergent)
0 commit comments