Skip to content

Commit d8188fb

Browse files
committed
Fix _approx_ftz_d implementations for rcp + rsqrt
Fix the rcp_approx_ftz_d and rsqrt_approx_ftz_d implementations to better match the PTX spec, which states that the lower 32 bits of the mantissa should be zeroed in both the input and the output.
1 parent 026166f commit d8188fb

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,6 +2006,15 @@ static const APFloat FTZPreserveSign(const APFloat &V) {
20062006
return V;
20072007
}
20082008

2009+
// Keep only the upper 32-bit word of the input double (1.11.20 format:
2010+
// 1 sign, 11 exponent, 20 mantissa bits) by zeroing the lower 32 bits of the mantissa.
2011+
static const APFloat ZeroLower32Bits(const APFloat &V) {
2012+
assert(V.getSizeInBits(V.getSemantics()) == 64); // V must be a 64-bit value
2013+
uint64_t DoubleBits = V.bitcastToAPInt().getZExtValue();
2014+
DoubleBits &= 0xffffffff00000000; // clear the low 32 bits, keep the high 32
2015+
return APFloat(V.getSemantics(), APInt(64, DoubleBits, false, false));
2016+
}
2017+
20092018
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
20102019
bool ShouldFTZPreservingSign = false) {
20112020
llvm_fenv_clearexcept();
@@ -2651,6 +2660,8 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
26512660
bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);
26522661

26532662
auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
2663+
if (IntrinsicID == Intrinsic::nvvm_rcp_approx_ftz_d)
2664+
Denominator = ZeroLower32Bits(Denominator);
26542665
if (IsApprox && Denominator.isZero()) {
26552666
// According to the PTX spec, approximate rcp should return infinity
26562667
// with the same sign as the denominator when dividing by 0.
@@ -2663,6 +2674,8 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
26632674
if (Status == APFloat::opOK || Status == APFloat::opInexact) {
26642675
if (IsFTZ)
26652676
Res = FTZPreserveSign(Res);
2677+
if (IntrinsicID == Intrinsic::nvvm_rcp_approx_ftz_d)
2678+
Res = ZeroLower32Bits(Res);
26662679
return ConstantFP::get(Ty->getContext(), Res);
26672680
}
26682681
return nullptr;
@@ -2680,14 +2693,24 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
26802693
case Intrinsic::nvvm_rsqrt_approx_f: {
26812694
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
26822695
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
2696+
2697+
if (IntrinsicID == Intrinsic::nvvm_rsqrt_approx_ftz_d)
2698+
V = ZeroLower32Bits(V);
2699+
26832700
APFloat SqrtV(sqrt(V.convertToDouble()));
26842701

2685-
bool lost;
2686-
SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven, &lost);
2702+
if (Ty->isFloatTy()) {
2703+
bool lost;
2704+
SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven,
2705+
&lost);
2706+
}
26872707

26882708
APFloat Res = APFloat::getOne(APF.getSemantics());
26892709
Res.divide(SqrtV, APFloat::rmNearestTiesToEven);
26902710

2711+
if (IntrinsicID == Intrinsic::nvvm_rsqrt_approx_ftz_d)
2712+
Res = ZeroLower32Bits(Res);
2713+
26912714
// We do not need to flush the output for ftz because it is impossible
26922715
// for 1/sqrt(x) to be a denormal value. If x is the largest fp value,
26932716
// sqrt(x) will be a number with the exponent approximately halved and

llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ define float @test_rcp_approx_ftz_f_0_5() {
551551

552552
define double @test_rcp_approx_ftz_d_neg_subnorm() {
553553
; CHECK-LABEL: define double @test_rcp_approx_ftz_d_neg_subnorm() {
554-
; CHECK-NEXT: ret double 0xC7D0000020000040
554+
; CHECK-NEXT: ret double 0xC7D0000000000000
555555
;
556556
%res = call double @llvm.nvvm.rcp.approx.ftz.d(double 0xB80FFFFFC0000000)
557557
ret double %res
@@ -568,7 +568,7 @@ define float @test_rcp_approx_ftz_f_neg_subnorm() {
568568

569569
define double @test_rcp_approx_ftz_d_pos_subnorm() {
570570
; CHECK-LABEL: define double @test_rcp_approx_ftz_d_pos_subnorm() {
571-
; CHECK-NEXT: ret double 0x47D0000020000040
571+
; CHECK-NEXT: ret double 0x47D0000000000000
572572
;
573573
%res = call double @llvm.nvvm.rcp.approx.ftz.d(double 0x380FFFFFC0000000)
574574
ret double %res
@@ -658,7 +658,7 @@ define float @test_rsqrt_approx_f_1_25() {
658658

659659
define double @test_rsqrt_approx_ftz_d_1_25() {
660660
; CHECK-LABEL: define double @test_rsqrt_approx_ftz_d_1_25() {
661-
; CHECK-NEXT: ret double 0x3FEC9F25C5BFEDD9
661+
; CHECK-NEXT: ret double 0x3FEC9F2500000000
662662
;
663663
%res = call double @llvm.nvvm.rsqrt.approx.ftz.d(double 1.25)
664664
ret double %res
@@ -690,7 +690,7 @@ define float @test_rsqrt_approx_f_pos_subnorm() {
690690

691691
define double @test_rsqrt_approx_ftz_d_pos_subnorm() {
692692
; CHECK-LABEL: define double @test_rsqrt_approx_ftz_d_pos_subnorm() {
693-
; CHECK-NEXT: ret double 0x43E0000010000018
693+
; CHECK-NEXT: ret double 0x43E0000000000000
694694
;
695695
%res = call double @llvm.nvvm.rsqrt.approx.ftz.d(double 0x380FFFFFC0000000)
696696
ret double %res

0 commit comments

Comments
 (0)