Skip to content

Commit 815fc44

Browse files
committed
[NVPTX] Consistently check fast-math flags when lowering fsqrt
1 parent f2c2eba commit 815fc44

File tree

8 files changed: +44 additions, -39 deletions

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ NVPTXDAGToDAGISel::getDivF32Level(const SDNode *N) const {
7171
return Subtarget->getTargetLowering()->getDivF32Level(*MF, *N);
7272
}
7373

74-
bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {
75-
return Subtarget->getTargetLowering()->usePrecSqrtF32();
74+
bool NVPTXDAGToDAGISel::usePrecSqrtF32(const SDNode *N) const {
75+
return Subtarget->getTargetLowering()->usePrecSqrtF32(*MF, N);
7676
}
7777

7878
bool NVPTXDAGToDAGISel::useF32FTZ() const {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4444
bool doMulWide;
4545

4646
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
47-
bool usePrecSqrtF32() const;
47+
bool usePrecSqrtF32(const SDNode *N) const;
4848
bool useF32FTZ() const;
4949
bool allowFMA() const;
5050
bool allowUnsafeFPMath() const;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,23 @@ NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
134134
return NVPTX::DivPrecisionLevel::IEEE754;
135135
}
136136

137-
bool NVPTXTargetLowering::usePrecSqrtF32() const {
138-
if (UsePrecSqrtF32.getNumOccurrences() > 0) {
139-
// If nvptx-prec-sqrtf32 is used on the command-line, always honor it
137+
bool NVPTXTargetLowering::usePrecSqrtF32(const MachineFunction &MF,
138+
const SDNode *N) const {
139+
// If nvptx-prec-sqrtf32 is used on the command-line, always honor it
140+
if (UsePrecSqrtF32.getNumOccurrences() > 0)
140141
return UsePrecSqrtF32;
141-
} else {
142-
// Otherwise, use sqrt.approx if fast math is enabled
143-
return !getTargetMachine().Options.UnsafeFPMath;
142+
143+
// Otherwise, use sqrt.approx if unsafe FP math or the node's afn flag allows
144+
if (allowUnsafeFPMath(MF))
145+
return false;
146+
147+
if (N) {
148+
const SDNodeFlags Flags = N->getFlags();
149+
if (Flags.hasApproximateFuncs())
150+
return false;
144151
}
152+
153+
return true;
145154
}
146155

147156
bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
@@ -1134,7 +1143,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
11341143
bool &UseOneConst,
11351144
bool Reciprocal) const {
11361145
if (!(Enabled == ReciprocalEstimate::Enabled ||
1137-
(Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1146+
(Enabled == ReciprocalEstimate::Unspecified &&
1147+
!usePrecSqrtF32(DAG.getMachineFunction()))))
11381148
return SDValue();
11391149

11401150
if (ExtraSteps == ReciprocalEstimate::Unspecified)

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ class NVPTXTargetLowering : public TargetLowering {
225225

226226
// Get whether we should use a precise or approximate 32-bit floating point
227227
// sqrt instruction.
228-
bool usePrecSqrtF32() const;
228+
bool usePrecSqrtF32(const MachineFunction &MF,
229+
const SDNode *N = nullptr) const;
229230

230231
// Get whether we should use instructions that flush floating-point denormals
231232
// to sign-preserving zero.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
151151

152152
def doMulWide : Predicate<"doMulWide">;
153153

154-
def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">;
155-
def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;
156-
157154
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
158155
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
159156
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,15 +1525,18 @@ def INT_NVVM_SQRT_RZ_D : F_MATH_1<"sqrt.rz.f64", F64RT, F64RT, int_nvvm_sqrt_rz_
15251525
def INT_NVVM_SQRT_RM_D : F_MATH_1<"sqrt.rm.f64", F64RT, F64RT, int_nvvm_sqrt_rm_d>;
15261526
def INT_NVVM_SQRT_RP_D : F_MATH_1<"sqrt.rp.f64", F64RT, F64RT, int_nvvm_sqrt_rp_d>;
15271527

1528+
def fsqrt_approx : PatFrags<(ops node:$a),
1529+
[(fsqrt node:$a),
1530+
(int_nvvm_sqrt_f node:$a)], [{
1531+
return !usePrecSqrtF32(N);
1532+
}]>;
1533+
15281534
// nvvm_sqrt intrinsic
1529-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1530-
(INT_NVVM_SQRT_RN_FTZ_F $a)>, Requires<[doF32FTZ, do_SQRTF32_RN]>;
1531-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1532-
(INT_NVVM_SQRT_RN_F $a)>, Requires<[do_SQRTF32_RN]>;
1533-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1534-
(INT_NVVM_SQRT_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>;
1535-
def : Pat<(int_nvvm_sqrt_f f32:$a),
1536-
(INT_NVVM_SQRT_APPROX_F $a)>;
1535+
def : Pat<(int_nvvm_sqrt_f f32:$a), (INT_NVVM_SQRT_RN_FTZ_F $a)>, Requires<[doF32FTZ]>;
1536+
def : Pat<(int_nvvm_sqrt_f f32:$a), (INT_NVVM_SQRT_RN_F $a)>;
1537+
1538+
def : Pat<(fsqrt_approx f32:$a), (INT_NVVM_SQRT_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>;
1539+
def : Pat<(fsqrt_approx f32:$a), (INT_NVVM_SQRT_APPROX_F $a)>;
15371540

15381541
//
15391542
// Rsqrt
@@ -1556,20 +1559,14 @@ def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_f f32:$a)),
15561559
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
15571560
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
15581561
Requires<[doRsqrtOpt]>;
1559-
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
1560-
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
1561-
(INT_NVVM_RSQRT_APPROX_F $a)>,
1562-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1563-
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
1564-
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
1565-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
15661562

1567-
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
1563+
// same for fsqrt and int_nvvm_sqrt_f when an approximate sqrt is allowed
1564+
def: Pat<(fdiv f32imm_1, (fsqrt_approx f32:$a)),
15681565
(INT_NVVM_RSQRT_APPROX_F $a)>,
1569-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1570-
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
1566+
Requires<[doRsqrtOpt, doNoF32FTZ]>;
1567+
def: Pat<(fdiv f32imm_1, (fsqrt_approx f32:$a)),
15711568
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
1572-
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
1569+
Requires<[doRsqrtOpt, doF32FTZ]>;
15731570
//
15741571
// Add
15751572
//

llvm/test/CodeGen/NVPTX/fast-math.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ define float @sqrt_div_fast(float %a, float %b) #0 {
2929
; CHECK-EMPTY:
3030
; CHECK-NEXT: // %bb.0:
3131
; CHECK-NEXT: ld.param.b32 %r1, [sqrt_div_fast_param_0];
32-
; CHECK-NEXT: sqrt.rn.f32 %r2, %r1;
32+
; CHECK-NEXT: sqrt.approx.f32 %r2, %r1;
3333
; CHECK-NEXT: ld.param.b32 %r3, [sqrt_div_fast_param_1];
3434
; CHECK-NEXT: div.approx.f32 %r4, %r2, %r3;
3535
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
@@ -84,7 +84,7 @@ define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 {
8484
; CHECK-EMPTY:
8585
; CHECK-NEXT: // %bb.0:
8686
; CHECK-NEXT: ld.param.b32 %r1, [sqrt_div_fast_ftz_param_0];
87-
; CHECK-NEXT: sqrt.rn.ftz.f32 %r2, %r1;
87+
; CHECK-NEXT: sqrt.approx.ftz.f32 %r2, %r1;
8888
; CHECK-NEXT: ld.param.b32 %r3, [sqrt_div_fast_ftz_param_1];
8989
; CHECK-NEXT: div.approx.ftz.f32 %r4, %r2, %r3;
9090
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;

llvm/test/CodeGen/NVPTX/sqrt-approx.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ define float @test_sqrt32(float %a) #0 {
8383
; CHECK-EMPTY:
8484
; CHECK-NEXT: // %bb.0:
8585
; CHECK-NEXT: ld.param.b32 %r1, [test_sqrt32_param_0];
86-
; CHECK-NEXT: sqrt.rn.f32 %r2, %r1;
86+
; CHECK-NEXT: sqrt.approx.f32 %r2, %r1;
8787
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
8888
; CHECK-NEXT: ret;
8989
%ret = tail call float @llvm.sqrt.f32(float %a)
@@ -115,7 +115,7 @@ define float @test_sqrt_ftz(float %a) #0 #1 {
115115
; CHECK-EMPTY:
116116
; CHECK-NEXT: // %bb.0:
117117
; CHECK-NEXT: ld.param.b32 %r1, [test_sqrt_ftz_param_0];
118-
; CHECK-NEXT: sqrt.rn.ftz.f32 %r2, %r1;
118+
; CHECK-NEXT: sqrt.approx.ftz.f32 %r2, %r1;
119119
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
120120
; CHECK-NEXT: ret;
121121
%ret = tail call float @llvm.sqrt.f32(float %a)
@@ -240,7 +240,7 @@ define float @test_sqrt32_refined(float %a) #0 #2 {
240240
; CHECK-EMPTY:
241241
; CHECK-NEXT: // %bb.0:
242242
; CHECK-NEXT: ld.param.b32 %r1, [test_sqrt32_refined_param_0];
243-
; CHECK-NEXT: sqrt.rn.f32 %r2, %r1;
243+
; CHECK-NEXT: sqrt.approx.f32 %r2, %r1;
244244
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
245245
; CHECK-NEXT: ret;
246246
%ret = tail call float @llvm.sqrt.f32(float %a)
@@ -352,7 +352,7 @@ define float @test_sqrt32_refined_ftz(float %a) #0 #1 #2 {
352352
; CHECK-EMPTY:
353353
; CHECK-NEXT: // %bb.0:
354354
; CHECK-NEXT: ld.param.b32 %r1, [test_sqrt32_refined_ftz_param_0];
355-
; CHECK-NEXT: sqrt.rn.ftz.f32 %r2, %r1;
355+
; CHECK-NEXT: sqrt.approx.ftz.f32 %r2, %r1;
356356
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
357357
; CHECK-NEXT: ret;
358358
%ret = tail call float @llvm.sqrt.f32(float %a)

0 commit comments

Comments
 (0)