From 2a284e412dfbaa28f8e82a52f34ef035b44d095e Mon Sep 17 00:00:00 2001 From: Alex Maclean Date: Fri, 23 May 2025 18:20:50 +0000 Subject: [PATCH] [NVPTX] Add -nvptx-prec-divf32=3 to disable ftz for f32 fdiv --- llvm/lib/Target/NVPTX/NVPTX.h | 1 + llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 16 +- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 174 +++++++++--------- .../CodeGen/NVPTX/nvptx-prec-divf32-flag.ll | 93 ++++++++++ 4 files changed, 193 insertions(+), 91 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/nvptx-prec-divf32-flag.ll diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index 8464028b8dc76..b7c5a0a5c9983 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -258,6 +258,7 @@ enum class DivPrecisionLevel : unsigned { Approx = 0, Full = 1, IEEE754 = 2, + IEEE754_NoFTZ = 3, }; } // namespace NVPTX diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 51f4682c5ba15..26736467778b0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -87,13 +87,15 @@ static cl::opt FMAContractLevelOpt( static cl::opt UsePrecDivF32( "nvptx-prec-divf32", cl::Hidden, - cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use" - " IEEE Compliant F32 div.rnd if available."), - cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", - "Use div.approx"), - clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"), - clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2", - "Use IEEE Compliant F32 div.rnd if available")), + cl::desc( + "NVPTX Specific: Override the precision of the lowering for f32 fdiv"), + cl::values( + clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"), + clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"), + clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2", + "Use IEEE Compliant F32 div.rnd if available (default)"), + clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3", + "Use IEEE Compliant F32 div.rnd if available, no FTZ")), cl::init(NVPTX::DivPrecisionLevel::IEEE754)); static cl::opt UsePrecSqrtF32( diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index d0538f579f94e..5076e25dbaa29 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1222,20 +1222,20 @@ def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>; // F64 division // def FRCP64r : - NVPTXInst<(outs Float64Regs:$dst), - (ins Float64Regs:$b), - "rcp.rn.f64 \t$dst, $b;", - [(set f64:$dst, (fdiv f64imm_1, f64:$b))]>; + BasicNVPTXInst<(outs Float64Regs:$dst), + (ins Float64Regs:$b), + "rcp.rn.f64", + [(set f64:$dst, (fdiv f64imm_1, f64:$b))]>; def FDIV64rr : - NVPTXInst<(outs Float64Regs:$dst), - (ins Float64Regs:$a, Float64Regs:$b), - "div.rn.f64 \t$dst, $a, $b;", - [(set f64:$dst, (fdiv f64:$a, f64:$b))]>; + BasicNVPTXInst<(outs Float64Regs:$dst), + (ins Float64Regs:$a, Float64Regs:$b), + "div.rn.f64", + [(set f64:$dst, (fdiv f64:$a, f64:$b))]>; def FDIV64ri : - NVPTXInst<(outs Float64Regs:$dst), - (ins Float64Regs:$a, f64imm:$b), - "div.rn.f64 \t$dst, $a, $b;", - [(set f64:$dst, (fdiv f64:$a, fpimm:$b))]>; + BasicNVPTXInst<(outs Float64Regs:$dst), + (ins Float64Regs:$a, f64imm:$b), + "div.rn.f64", + [(set f64:$dst, (fdiv f64:$a, fpimm:$b))]>; // fdiv will be converted to rcp // fneg (fdiv 1.0, X) => fneg (rcp.rn X) @@ -1253,42 +1253,42 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b), def FRCP32_approx_r_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$b), - "rcp.approx.ftz.f32 \t$dst, $b;", - [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$b), + "rcp.approx.ftz.f32", + [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>, + Requires<[doF32FTZ]>; def FRCP32_approx_r : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$b), - "rcp.approx.f32 \t$dst, $b;", - [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$b), + "rcp.approx.f32", + [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>; // // F32 Approximate division // def FDIV32approxrr_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - "div.approx.ftz.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, Float32Regs:$b), + "div.approx.ftz.f32", + [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>, + Requires<[doF32FTZ]>; def FDIV32approxri_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - "div.approx.ftz.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, f32imm:$b), + "div.approx.ftz.f32", + [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>, + Requires<[doF32FTZ]>; def FDIV32approxrr : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - "div.approx.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, Float32Regs:$b), + "div.approx.f32", + [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>; def FDIV32approxri : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - "div.approx.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, f32imm:$b), + "div.approx.f32", + [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>; // // F32 Semi-accurate reciprocal // @@ -1312,66 +1312,72 @@ def : Pat<(fdiv_full f32imm_1, f32:$b), // F32 Semi-accurate division // def FDIV32rr_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - "div.full.ftz.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, Float32Regs:$b), + "div.full.ftz.f32", + [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>, + Requires<[doF32FTZ]>; def FDIV32ri_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - "div.full.ftz.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, f32imm:$b), + "div.full.ftz.f32", + [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>, + Requires<[doF32FTZ]>; def FDIV32rr : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - "div.full.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, Float32Regs:$b), + "div.full.f32", + [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>; def FDIV32ri : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - "div.full.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, f32imm:$b), + "div.full.f32", + [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>; // // F32 Accurate reciprocal // + +def fdiv_ftz : PatFrag<(ops node:$a, node:$b), + (fdiv node:$a, node:$b), [{ + return getDivF32Level(N) == NVPTX::DivPrecisionLevel::IEEE754; +}]>; + def FRCP32r_prec_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$b), - "rcp.rn.ftz.f32 \t$dst, $b;", - [(set f32:$dst, (fdiv f32imm_1, f32:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$b), + "rcp.rn.ftz.f32", + [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>, + Requires<[doF32FTZ]>; def FRCP32r_prec : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$b), - "rcp.rn.f32 \t$dst, $b;", - [(set f32:$dst, (fdiv f32imm_1, f32:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$b), + "rcp.rn.f32", + [(set f32:$dst, (fdiv f32imm_1, f32:$b))]>; // // F32 Accurate division // def FDIV32rr_prec_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - "div.rn.ftz.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv f32:$a, f32:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, Float32Regs:$b), + "div.rn.ftz.f32", + [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>, + Requires<[doF32FTZ]>; def FDIV32ri_prec_ftz : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - "div.rn.ftz.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>, - Requires<[doF32FTZ]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, f32imm:$b), + "div.rn.ftz.f32", + [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>, + Requires<[doF32FTZ]>; def FDIV32rr_prec : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - "div.rn.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv f32:$a, f32:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, Float32Regs:$b), + "div.rn.f32", + [(set f32:$dst, (fdiv f32:$a, f32:$b))]>; def FDIV32ri_prec : - NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - "div.rn.f32 \t$dst, $a, $b;", - [(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>; + BasicNVPTXInst<(outs Float32Regs:$dst), + (ins Float32Regs:$a, f32imm:$b), + "div.rn.f32", + [(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>; // // FMA diff --git a/llvm/test/CodeGen/NVPTX/nvptx-prec-divf32-flag.ll b/llvm/test/CodeGen/NVPTX/nvptx-prec-divf32-flag.ll new file mode 100644 index 0000000000000..aaa3dfa86b1d1 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/nvptx-prec-divf32-flag.ll @@ -0,0 +1,93 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=0 | FileCheck %s --check-prefix=APPROX +; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=1 | FileCheck %s --check-prefix=FULL +; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=2 | FileCheck %s --check-prefixes=IEEE,FTZ +; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=3 | FileCheck %s --check-prefixes=IEEE,NOFTZ + +target triple = "nvptx64-nvidia-cuda" + +define float @div_ftz(float %a, float %b) "denormal-fp-math-f32" = "preserve-sign" { +; APPROX-LABEL: div_ftz( +; APPROX: { +; APPROX-NEXT: .reg .b32 %r<4>; +; APPROX-EMPTY: +; APPROX-NEXT: // %bb.0: +; APPROX-NEXT: ld.param.b32 %r1, [div_ftz_param_0]; +; APPROX-NEXT: ld.param.b32 %r2, [div_ftz_param_1]; +; APPROX-NEXT: div.approx.ftz.f32 %r3, %r1, %r2; +; APPROX-NEXT: st.param.b32 [func_retval0], %r3; +; APPROX-NEXT: ret; +; +; FULL-LABEL: div_ftz( +; FULL: { +; FULL-NEXT: .reg .b32 %r<4>; +; FULL-EMPTY: +; FULL-NEXT: // %bb.0: +; FULL-NEXT: ld.param.b32 %r1, [div_ftz_param_0]; +; FULL-NEXT: ld.param.b32 %r2, [div_ftz_param_1]; +; FULL-NEXT: div.full.ftz.f32 %r3, %r1, %r2; +; FULL-NEXT: st.param.b32 [func_retval0], %r3; +; FULL-NEXT: ret; +; +; FTZ-LABEL: div_ftz( +; FTZ: { +; FTZ-NEXT: .reg .b32 %r<4>; +; FTZ-EMPTY: +; FTZ-NEXT: // %bb.0: +; FTZ-NEXT: ld.param.b32 %r1, [div_ftz_param_0]; +; FTZ-NEXT: ld.param.b32 %r2, [div_ftz_param_1]; +; FTZ-NEXT: div.rn.ftz.f32 %r3, %r1, %r2; +; FTZ-NEXT: st.param.b32 [func_retval0], %r3; +; FTZ-NEXT: ret; +; +; NOFTZ-LABEL: div_ftz( +; NOFTZ: { +; NOFTZ-NEXT: .reg .b32 %r<4>; +; NOFTZ-EMPTY: +; NOFTZ-NEXT: // %bb.0: +; NOFTZ-NEXT: ld.param.b32 %r1, [div_ftz_param_0]; +; NOFTZ-NEXT: ld.param.b32 %r2, [div_ftz_param_1]; +; NOFTZ-NEXT: div.rn.f32 %r3, %r1, %r2; +; NOFTZ-NEXT: st.param.b32 [func_retval0], %r3; +; NOFTZ-NEXT: ret; + %val = fdiv float %a, %b + ret float %val +} + + +define float @div(float %a, float %b) { +; APPROX-LABEL: div( +; APPROX: { +; APPROX-NEXT: .reg .b32 %r<4>; +; APPROX-EMPTY: +; APPROX-NEXT: // %bb.0: +; APPROX-NEXT: ld.param.b32 %r1, [div_param_0]; +; APPROX-NEXT: ld.param.b32 %r2, [div_param_1]; +; APPROX-NEXT: div.approx.f32 %r3, %r1, %r2; +; APPROX-NEXT: st.param.b32 [func_retval0], %r3; +; APPROX-NEXT: ret; +; +; FULL-LABEL: div( +; FULL: { +; FULL-NEXT: .reg .b32 %r<4>; +; FULL-EMPTY: +; FULL-NEXT: // %bb.0: +; FULL-NEXT: ld.param.b32 %r1, [div_param_0]; +; FULL-NEXT: ld.param.b32 %r2, [div_param_1]; +; FULL-NEXT: div.full.f32 %r3, %r1, %r2; +; FULL-NEXT: st.param.b32 [func_retval0], %r3; +; FULL-NEXT: ret; +; +; IEEE-LABEL: div( +; IEEE: { +; IEEE-NEXT: .reg .b32 %r<4>; +; IEEE-EMPTY: +; IEEE-NEXT: // %bb.0: +; IEEE-NEXT: ld.param.b32 %r1, [div_param_0]; +; IEEE-NEXT: ld.param.b32 %r2, [div_param_1]; +; IEEE-NEXT: div.rn.f32 %r3, %r1, %r2; +; IEEE-NEXT: st.param.b32 [func_retval0], %r3; +; IEEE-NEXT: ret; + %val = fdiv float %a, %b + ret float %val +}