Skip to content

Commit 26674a6

Browse files
committed
- Created exhaustive test for rsqrt in libc/test/src/math/exhaustive/
- Created unit test template for rsqrt functions that can be called in the future by other precision functions - Changed range of values for unit-test to match the libc/test/src/math/SqrtTest.h - Changed the comments in rsqrtf.h, removed unnecessary code in the comments
1 parent 5b8b21a commit 26674a6

File tree

5 files changed

+123
-45
lines changed

5 files changed

+123
-45
lines changed

libc/src/__support/math/rsqrtf.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,10 @@ LIBC_INLINE static constexpr float rsqrtf(float x) {
5858
return FPBits::zero().get_val();
5959
}
6060

61-
// TODO: add integer based implementation when LIBC_TARGET_CPU_HAS_FPU_FLOAT
62-
// is not defined
61+
// TODO: add float based approximation when
62+
// LIBC_TARGET_CPU_HAS_FPU_DOUBLE is not defined
6363
double result = 1.0f / fputil::sqrt<double>(fputil::cast<double>(x));
6464

65-
// Targeted post-corrections to ensure correct rounding in half for specific
66-
// mantissa patterns
67-
/*
68-
const uint32_t half_mantissa = x_abs & 0x3ff;
69-
if (LIBC_UNLIKELY(half_mantissa == 0x011F)) {
70-
result = fputil::multiply_add(result, 0x1.0p-21f, result);
71-
} else if (LIBC_UNLIKELY(half_mantissa == 0x0313)) {
72-
result = fputil::multiply_add(result, -0x1.0p-21f, result);
73-
}*/
74-
7565
return fputil::cast<float>(result);
7666
}
7767

libc/test/src/math/RsqrtTest.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//===-- Utility class to test rsqrt -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_TEST_SRC_MATH_RSQRTTEST_H
10+
#define LLVM_LIBC_TEST_SRC_MATH_RSQRTTEST_H
11+
12+
#include "test/UnitTest/FEnvSafeTest.h"
13+
#include "test/UnitTest/FPMatcher.h"
14+
#include "test/UnitTest/Test.h"
15+
#include "utils/MPFRWrapper/MPFRUtils.h"
16+
17+
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
18+
19+
template <typename OutType, typename InType = OutType>
20+
class RsqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
21+
22+
DECLARE_SPECIAL_CONSTANTS(InType)
23+
24+
static constexpr StorageType HIDDEN_BIT =
25+
StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<InType>::FRACTION_LEN;
26+
27+
public:
28+
using RsqrtFunc = OutType (*)(InType);
29+
30+
// Subnormal inputs: probe both power-of-two mantissas and an even sampling
31+
// across the subnormal range.
32+
void test_denormal_values(RsqrtFunc func) {
33+
// Powers of two in the subnormal mantissa space.
34+
for (StorageType mant = 1; mant < HIDDEN_BIT; mant <<= 1) {
35+
FPBits denormal(zero);
36+
denormal.set_mantissa(mant);
37+
InType x = denormal.get_val();
38+
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x, func(x), 0.5);
39+
}
40+
41+
// Even sampling across all subnormals.
42+
constexpr StorageType COUNT = 200'001;
43+
constexpr StorageType STEP = HIDDEN_BIT / COUNT;
44+
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
45+
InType x = FPBits(i).get_val();
46+
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x, func(x), 0.5);
47+
}
48+
}
49+
50+
// Positive normal range sampling: skip NaNs and negative values.
51+
void test_normal_range(RsqrtFunc func) {
52+
constexpr StorageType COUNT = 200'001;
53+
constexpr StorageType STEP = STORAGE_MAX / COUNT;
54+
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
55+
FPBits x_bits(v);
56+
InType x = x_bits.get_val();
57+
if (x_bits.is_nan() || (x < static_cast<InType>(0)))
58+
continue;
59+
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x, func(x), 0.5);
60+
}
61+
}
62+
};
63+
64+
#define LIST_RSQRT_TESTS(T, func) \
65+
using LlvmLibcRsqrtTest = RsqrtTest<T, T>; \
66+
TEST_F(LlvmLibcRsqrtTest, DenormalValues) { test_denormal_values(&func); } \
67+
TEST_F(LlvmLibcRsqrtTest, NormalRange) { test_normal_range(&func); }
68+
69+
#define LIST_NARROWING_RSQRT_TESTS(OutType, InType, func) \
70+
using LlvmLibcRsqrtTest = RsqrtTest<OutType, InType>; \
71+
TEST_F(LlvmLibcRsqrtTest, DenormalValues) { test_denormal_values(&func); } \
72+
TEST_F(LlvmLibcRsqrtTest, NormalRange) { test_normal_range(&func); }
73+
74+
#endif // LLVM_LIBC_TEST_SRC_MATH_RSQRTTEST_H

libc/test/src/math/exhaustive/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,22 @@ add_fp_unittest(
2626
-lpthread
2727
)
2828

29+
add_fp_unittest(
30+
rsqrtf_test
31+
NO_RUN_POSTBUILD
32+
NEED_MPFR
33+
SUITE
34+
libc_math_exhaustive_tests
35+
SRCS
36+
rsqrtf_test.cpp
37+
DEPENDS
38+
.exhaustive_test
39+
libc.src.math.rsqrtf
40+
libc.src.__support.FPUtil.fp_bits
41+
LINK_LIBRARIES
42+
-lpthread
43+
)
44+
2945
add_fp_unittest(
3046
sinf_test
3147
NO_RUN_POSTBUILD
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===-- Exhaustive test for rsqrtf ----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "exhaustive_test.h"
10+
#include "src/math/rsqrtf.h"
11+
#include "utils/MPFRWrapper/MPFRUtils.h"
12+
13+
using LlvmLibcRsqrtfTest = LIBC_NAMESPACE::testing::FPTest<float>;
14+
15+
using LlvmLibcRsqrtfExhaustiveTest =
16+
LlvmLibcUnaryOpExhaustiveMathTest<float, mpfr::Operation::Rsqrt,
17+
LIBC_NAMESPACE::rsqrtf>;
18+
19+
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
20+
21+
// Range: [0, Inf]
22+
static constexpr uint32_t POS_START = 0x0000'0000U;
23+
static constexpr uint32_t POS_STOP = 0x7f80'0000U;
24+
25+
TEST_F(LlvmLibcRsqrtfExhaustiveTest, PositiveRange) {
26+
test_full_range_all_roundings(POS_START, POS_STOP);
27+
}

libc/test/src/math/rsqrtf_test.cpp

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,12 @@
1-
//===-- Exhaustive test for rsqrtf ----------------------------------------===//
2-
//
1+
//===-- Unittests for rsqrtf ----------------------------------------------===//
32
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43
// See https://llvm.org/LICENSE.txt for license information.
54
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65
//
76
//===----------------------------------------------------------------------===//
87

9-
#include "src/math/rsqrtf.h"
10-
#include "test/UnitTest/FPMatcher.h"
11-
#include "test/UnitTest/Test.h"
12-
#include "utils/MPFRWrapper/MPFRUtils.h"
13-
14-
using LlvmLibcRsqrtfTest = LIBC_NAMESPACE::testing::FPTest<float>;
15-
16-
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
8+
#include "RsqrtTest.h"
179

18-
// Range: [0, Inf]
19-
static constexpr uint32_t POS_START = 0x00000000u;
20-
static constexpr uint32_t POS_STOP = 0x7F800000u;
21-
22-
// Range: [-Inf, 0)
23-
// rsqrt(-0.0) is -inf, not the same for mpfr.
24-
static constexpr uint32_t NEG_START = 0x80000001u;
25-
static constexpr uint32_t NEG_STOP = 0xFF800000u;
26-
27-
TEST_F(LlvmLibcRsqrtfTest, PositiveRange) {
28-
for (uint32_t v = POS_START; v <= POS_STOP; ++v) {
29-
float x = FPBits(v).get_val();
30-
EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x,
31-
LIBC_NAMESPACE::rsqrtf(x), 0.5);
32-
}
33-
}
10+
#include "src/math/rsqrtf.h"
3411

35-
TEST_F(LlvmLibcRsqrtfTest, NegativeRange) {
36-
for (uint32_t v = NEG_START; v <= NEG_STOP; ++v) {
37-
float x = FPBits(v).get_val();
38-
EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x,
39-
LIBC_NAMESPACE::rsqrtf(x), 0.5);
40-
}
41-
}
12+
LIST_RSQRT_TESTS(float, LIBC_NAMESPACE::rsqrtf)

0 commit comments

Comments
 (0)