@@ -2521,33 +2521,11 @@ static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
25212521 return DAG.getFPExtendOrRound (Res, DL, VT);
25222522}
25232523
2524- SDValue NVPTXTargetLowering::LowerFADD (SDValue Op, SelectionDAG &DAG) const {
2525- // No fma.ftz for bf16, so fall back to promotion
2526- if (useF32FTZ (DAG.getMachineFunction ())) {
2527- return PromoteBinOpToF32 (Op.getNode (), DAG);
2528- }
2529-
2530- // Legal
2531- return Op;
2532- }
2533-
2534- SDValue NVPTXTargetLowering::LowerFSUB (SDValue Op, SelectionDAG &DAG) const {
2535- // No fma.ftz for bf16, so fall back to promotion
2536- if (useF32FTZ (DAG.getMachineFunction ())) {
2537- return PromoteBinOpToF32 (Op.getNode (), DAG);
2538- }
2539-
2540- // Legal
2541- return Op;
2542- }
2543-
2544- SDValue NVPTXTargetLowering::LowerFMUL (SDValue Op, SelectionDAG &DAG) const {
2545- // No fma.ftz for bf16, so fall back to promotion
2524+ SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ (SDValue Op,
2525+ SelectionDAG &DAG) const {
25462526 if (useF32FTZ (DAG.getMachineFunction ())) {
25472527 return PromoteBinOpToF32 (Op.getNode (), DAG);
25482528 }
2549-
2550- // Legal
25512529 return Op;
25522530}
25532531
@@ -2743,11 +2721,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27432721 case ISD::CopyToReg:
27442722 return LowerCopyToReg_128 (Op, DAG);
27452723 case ISD::FADD:
2746- return LowerFADD (Op, DAG);
27472724 case ISD::FSUB:
2748- return LowerFSUB (Op, DAG);
27492725 case ISD::FMUL:
2750- return LowerFMUL (Op, DAG);
2726+ // Used only for bf16 on SM80, where we select fma for non-ftz operation
2727+ return PromoteBinOpIfF32FTZ (Op, DAG);
27512728
27522729 default :
27532730 llvm_unreachable (" Custom lowering not defined for operation" );
0 commit comments