|
| 1 | +//===-- Pseudo-random number generation utilities ---------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#ifndef LLVM_LIBC_BENCHMARKS_GPU_RANDOM_H |
| 10 | +#define LLVM_LIBC_BENCHMARKS_GPU_RANDOM_H |
| 11 | + |
| 12 | +#include "hdr/stdint_proxy.h" |
| 13 | +#include "src/__support/CPP/algorithm.h" |
| 14 | +#include "src/__support/CPP/optional.h" |
| 15 | +#include "src/__support/CPP/type_traits.h" |
| 16 | +#include "src/__support/FPUtil/FPBits.h" |
| 17 | +#include "src/__support/macros/attributes.h" |
| 18 | +#include "src/__support/macros/config.h" |
| 19 | +#include "src/__support/macros/properties/types.h" |
| 20 | +#include "src/__support/sign.h" |
| 21 | + |
| 22 | +namespace LIBC_NAMESPACE_DECL { |
| 23 | +namespace benchmarks { |
| 24 | + |
| 25 | +// Pseudo-random number generator (PRNG) that produces unsigned 64-bit, 32-bit, |
| 26 | +// and 16-bit integers. The implementation is based on the xorshift* generator, |
| 27 | +// seeded using SplitMix64 for robust initialization. For more details, see: |
| 28 | +// https://en.wikipedia.org/wiki/Xorshift |
| 29 | +class RandomGenerator { |
| 30 | + uint64_t state; |
| 31 | + |
| 32 | + static LIBC_INLINE uint64_t splitmix64(uint64_t x) noexcept { |
| 33 | + x += 0x9E3779B97F4A7C15ULL; |
| 34 | + x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL; |
| 35 | + x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL; |
| 36 | + x = (x ^ (x >> 31)); |
| 37 | + return x ? x : 0x9E3779B97F4A7C15ULL; |
| 38 | + } |
| 39 | + |
| 40 | +public: |
| 41 | + explicit LIBC_INLINE RandomGenerator(uint64_t seed) noexcept |
| 42 | + : state(splitmix64(seed)) {} |
| 43 | + |
| 44 | + LIBC_INLINE uint64_t next64() noexcept { |
| 45 | + uint64_t x = state; |
| 46 | + x ^= x >> 12; |
| 47 | + x ^= x << 25; |
| 48 | + x ^= x >> 27; |
| 49 | + state = x; |
| 50 | + return x * 0x2545F4914F6CDD1DULL; |
| 51 | + } |
| 52 | + |
| 53 | + LIBC_INLINE uint32_t next32() noexcept { |
| 54 | + return static_cast<uint32_t>(next64() >> 32); |
| 55 | + } |
| 56 | + |
| 57 | + LIBC_INLINE uint16_t next16() noexcept { |
| 58 | + return static_cast<uint16_t>(next64() >> 48); |
| 59 | + } |
| 60 | +}; |
| 61 | + |
| 62 | +// Generates random floating-point numbers where the unbiased binary exponent |
| 63 | +// is sampled uniformly in `[min_exp, max_exp]`. The significand bits are |
| 64 | +// always randomized, while the sign is randomized by default but can be fixed. |
| 65 | +// Evenly covers orders of magnitude; never yields Inf/NaN. |
| 66 | +template <typename T> class UniformExponent { |
| 67 | + static_assert(cpp::is_same_v<T, float16> || cpp::is_same_v<T, float> || |
| 68 | + cpp::is_same_v<T, double>, |
| 69 | + "UniformExponent supports float16, float, and double"); |
| 70 | + |
| 71 | + using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>; |
| 72 | + using Storage = typename FPBits::StorageType; |
| 73 | + |
| 74 | +public: |
| 75 | + explicit UniformExponent(int min_exp = -FPBits::EXP_BIAS, |
| 76 | + int max_exp = FPBits::EXP_BIAS, |
| 77 | + cpp::optional<Sign> forced_sign = cpp::nullopt) |
| 78 | + : min_exp(clamp_exponent(cpp::min(min_exp, max_exp))), |
| 79 | + max_exp(clamp_exponent(cpp::max(min_exp, max_exp))), |
| 80 | + forced_sign(forced_sign) {} |
| 81 | + |
| 82 | + LIBC_INLINE T operator()(RandomGenerator &rng) const noexcept { |
| 83 | + // Sample unbiased exponent e uniformly in [min_exp, max_exp] without modulo |
| 84 | + // bias, using rejection sampling |
| 85 | + auto sample_in_range = [&](uint64_t r) -> int32_t { |
| 86 | + const uint64_t range = static_cast<uint64_t>( |
| 87 | + static_cast<int64_t>(max_exp) - static_cast<int64_t>(min_exp) + 1); |
| 88 | + const uint64_t threshold = (-range) % range; |
| 89 | + while (r < threshold) |
| 90 | + r = rng.next64(); |
| 91 | + return static_cast<int32_t>(min_exp + static_cast<int64_t>(r % range)); |
| 92 | + }; |
| 93 | + const int32_t e = sample_in_range(rng.next64()); |
| 94 | + |
| 95 | + // Start from random bits to get random sign and mantissa |
| 96 | + FPBits xbits([&] { |
| 97 | + if constexpr (cpp::is_same_v<T, double>) |
| 98 | + return FPBits(rng.next64()); |
| 99 | + else if constexpr (cpp::is_same_v<T, float>) |
| 100 | + return FPBits(rng.next32()); |
| 101 | + else |
| 102 | + return FPBits(rng.next16()); |
| 103 | + }()); |
| 104 | + |
| 105 | + if (e == -FPBits::EXP_BIAS) { |
| 106 | + // Subnormal: biased exponent must be 0; ensure mantissa != 0 to avoid 0 |
| 107 | + xbits.set_biased_exponent(Storage(0)); |
| 108 | + if (xbits.get_mantissa() == Storage(0)) |
| 109 | + xbits.set_mantissa(Storage(1)); |
| 110 | + } else { |
| 111 | + // Normal: biased exponent in [1, 2 * FPBits::EXP_BIAS] |
| 112 | + const int32_t biased = e + FPBits::EXP_BIAS; |
| 113 | + xbits.set_biased_exponent(static_cast<Storage>(biased)); |
| 114 | + } |
| 115 | + |
| 116 | + if (forced_sign) |
| 117 | + xbits.set_sign(*forced_sign); |
| 118 | + |
| 119 | + return xbits.get_val(); |
| 120 | + } |
| 121 | + |
| 122 | +private: |
| 123 | + static LIBC_INLINE int clamp_exponent(int val) noexcept { |
| 124 | + if (val < -FPBits::EXP_BIAS) |
| 125 | + return -FPBits::EXP_BIAS; |
| 126 | + |
| 127 | + if (val > FPBits::EXP_BIAS) |
| 128 | + return FPBits::EXP_BIAS; |
| 129 | + |
| 130 | + return val; |
| 131 | + } |
| 132 | + |
| 133 | + const int min_exp; |
| 134 | + const int max_exp; |
| 135 | + const cpp::optional<Sign> forced_sign; |
| 136 | +}; |
| 137 | + |
| 138 | +// Generates random floating-point numbers that are uniformly distributed on |
| 139 | +// a linear scale. Values are sampled from `[min_val, max_val)`. |
| 140 | +template <typename T> class UniformLinear { |
| 141 | + static_assert(cpp::is_same_v<T, float16> || cpp::is_same_v<T, float> || |
| 142 | + cpp::is_same_v<T, double>, |
| 143 | + "UniformLinear supports float16, float, and double"); |
| 144 | + |
| 145 | + using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>; |
| 146 | + using Storage = typename FPBits::StorageType; |
| 147 | + |
| 148 | + static constexpr T MAX_NORMAL = FPBits::max_normal().get_val(); |
| 149 | + |
| 150 | +public: |
| 151 | + explicit UniformLinear(T min_val = -MAX_NORMAL, T max_val = MAX_NORMAL) |
| 152 | + : min_val(clamp_val(cpp::min(min_val, max_val))), |
| 153 | + max_val(clamp_val(cpp::max(min_val, max_val))) {} |
| 154 | + |
| 155 | + LIBC_INLINE T operator()(RandomGenerator &rng) const noexcept { |
| 156 | + double u = standard_uniform(rng.next64()); |
| 157 | + double a = static_cast<double>(min_val); |
| 158 | + double b = static_cast<double>(max_val); |
| 159 | + double y = a + (b - a) * u; |
| 160 | + return static_cast<T>(y); |
| 161 | + } |
| 162 | + |
| 163 | +private: |
| 164 | + static LIBC_INLINE T clamp_val(T val) noexcept { |
| 165 | + if (val < -MAX_NORMAL) |
| 166 | + return -MAX_NORMAL; |
| 167 | + |
| 168 | + if (val > MAX_NORMAL) |
| 169 | + return MAX_NORMAL; |
| 170 | + |
| 171 | + return val; |
| 172 | + } |
| 173 | + |
| 174 | + static LIBC_INLINE double standard_uniform(uint64_t x) noexcept { |
| 175 | + constexpr int PREC_BITS = |
| 176 | + LIBC_NAMESPACE::fputil::FPBits<double>::SIG_LEN + 1; |
| 177 | + constexpr int SHIFT_BITS = LIBC_NAMESPACE::fputil::FPBits<double>::EXP_LEN; |
| 178 | + constexpr double INV = 1.0 / static_cast<double>(1ULL << PREC_BITS); |
| 179 | + |
| 180 | + return static_cast<double>(x >> SHIFT_BITS) * INV; |
| 181 | + } |
| 182 | + |
| 183 | + const T min_val; |
| 184 | + const T max_val; |
| 185 | +}; |
| 186 | + |
| 187 | +} // namespace benchmarks |
| 188 | +} // namespace LIBC_NAMESPACE_DECL |
| 189 | + |
| 190 | +#endif |
0 commit comments