@@ -6899,21 +6899,19 @@ SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
68996899 if (Op.getOpcode() != ISD::FP_ROUND)
69006900 return Op;
69016901
6902- if (Subtarget->has16BitInsts()) {
6903- if (getTargetMachine().Options.UnsafeFPMath) {
6904- SDValue Flags = Op.getOperand(1);
6905- SDValue Src32 = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src, Flags);
6906- return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Src32, Flags);
6907- } else {
6908- SDValue FpToFp16 = LowerF64ToF16(Src, DL, DAG);
6909- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
6910- return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
6911- }
6912- } else {
6902+ if (!Subtarget->has16BitInsts()) {
69136903 SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
69146904 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
69156905 return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
69166906 }
6907+ if (getTargetMachine().Options.UnsafeFPMath) {
6908+ SDValue Flags = Op.getOperand(1);
6909+ SDValue Src32 = DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Src, Flags);
6910+ return DAG.getNode(ISD::FP_ROUND, DL, MVT::f16, Src32, Flags);
6911+ }
6912+ SDValue FpToFp16 = LowerF64ToF16(Src, DL, DAG);
6913+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
6914+ return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
69176915 }
69186916
69196917 assert(DstVT.getScalarType() == MVT::bf16 &&
0 commit comments