Skip to content

Commit 0cb7c0d

Browse files
LukeBoyercopybara-github
authored andcommitted
Simplify litert_rng type selection, Remove redundant and confusing factory type.
LiteRT-PiperOrigin-RevId: 774993305
1 parent 6da00db commit 0cb7c0d

File tree

3 files changed

+59
-109
lines changed

3 files changed

+59
-109
lines changed

litert/cc/litert_rng.h

Lines changed: 30 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,11 @@ class RandomDevice {
116116

117117
// Abstract base class for generating data of a certain type from a given rng
118118
// device, e.g. populating tensors and the like.
119-
template <typename D, template <typename> typename Dist, typename DeviceBase>
119+
template <typename D, template <typename> typename Dist>
120120
class DataGenerator {
121121
public:
122122
using DataType = D;
123123
using Wide = WideType<D>;
124-
using Device = RandomDevice<DeviceBase>;
125-
126-
virtual DataType operator()(Device& rng) = 0;
127124

128125
// Bounds of distribution.
129126
DataType Max() const { return dist_.max(); }
@@ -134,14 +131,13 @@ class DataGenerator {
134131
};
135132

136133
// A data generator that generates data within a given range.
137-
template <typename D, template <typename> typename Dist, typename DeviceBase>
138-
class RangedGenerator final : public DataGenerator<D, Dist, DeviceBase> {
134+
template <typename D, template <typename> typename Dist>
135+
class RangedGenerator final : public DataGenerator<D, Dist> {
139136
private:
140-
using Base = DataGenerator<D, Dist, DeviceBase>;
137+
using Base = DataGenerator<D, Dist>;
141138

142139
public:
143140
using typename Base::DataType;
144-
using typename Base::Device;
145141
using typename Base::Wide;
146142

147143
RangedGenerator() = default;
@@ -155,30 +151,32 @@ class RangedGenerator final : public DataGenerator<D, Dist, DeviceBase> {
155151
RangedGenerator(RangedGenerator&&) = default;
156152
RangedGenerator& operator=(RangedGenerator&&) = default;
157153

158-
DataType operator()(Device& rng) override { return this->dist_(rng); }
154+
template <typename Rng>
155+
DataType operator()(Rng& rng) {
156+
return this->dist_(rng);
157+
}
159158
};
160159

161160
// A rangeless float generator that casts random bits to the given float type.
162161
// This generally produces higher quality floats more repersentative of the
163162
// target distribution than a ranged generator. Particularly covers more values
164163
// around zero and infinities.
165-
template <typename D, template <typename> typename Dist, typename DeviceBase,
166-
typename Enable = void>
167-
class ReinterpretGenerator final : public DataGenerator<D, Dist, DeviceBase> {};
164+
template <typename D, template <typename> typename Dist, typename Enable = void>
165+
class ReinterpretGenerator final : public DataGenerator<D, Dist> {};
168166

169-
template <typename D, template <typename> typename Dist, typename DeviceBase>
170-
class ReinterpretGenerator<D, Dist, DeviceBase,
167+
template <typename D, template <typename> typename Dist>
168+
class ReinterpretGenerator<D, Dist,
171169
std::enable_if_t<std::is_floating_point_v<D>>>
172-
final : public DataGenerator<D, Dist, DeviceBase> {
170+
final : public DataGenerator<D, Dist> {
173171
private:
174-
using Base = DataGenerator<D, Dist, DeviceBase>;
172+
using Base = DataGenerator<D, Dist>;
175173

176174
public:
177175
using typename Base::DataType;
178-
using typename Base::Device;
179176
using typename Base::Wide;
180177

181-
DataType operator()(Device& rng) override {
178+
template <typename Rng>
179+
DataType operator()(Rng& rng) {
182180
DataType res;
183181
auto bits = rng();
184182
memcpy(&res, &bits, sizeof(res));
@@ -195,72 +193,23 @@ class ReinterpretGenerator<D, Dist, DeviceBase,
195193
ReinterpretGenerator& operator=(ReinterpretGenerator&&) = default;
196194
};
197195

198-
// Recommended distribution for data generators.
199-
template <typename T>
200-
using Uniform =
201-
SelectT<std::is_floating_point<T>, std::uniform_real_distribution<T>,
202-
std::is_integral<T>, std::uniform_int_distribution<T>>;
203-
204-
// Recommended engine for data generators.
205-
using DefaultEngine = std::mt19937_64;
206-
207-
// Factory for creating data generators from just a data type with recommended
208-
// defaults.
209-
template <typename D, template <typename> typename Distribution = Uniform,
210-
typename Engine = DefaultEngine>
211-
class DataGenerators {
212-
// Exotic types not yet supported (e.g. quant, complex, half-precision etc).
213-
static_assert(std::is_floating_point_v<D> || std::is_integral_v<D>);
196+
// DEFAULTS FOR DATA GENERATORS ////////////////////////////////////////////////
214197

215-
private:
216-
using GeneratorBase = DataGenerator<D, Distribution, Engine>;
217-
218-
public:
219-
using Reinterpret = ReinterpretGenerator<D, Uniform, Engine>;
220-
using Ranged = RangedGenerator<D, Uniform, Engine>;
221-
using Dataype = GeneratorBase::DataType;
222-
using Wide = GeneratorBase::Wide;
223-
using RandomDevice = GeneratorBase::Device;
224-
225-
DataGenerators() = default;
226-
DataGenerators(const DataGenerators&) = default;
227-
DataGenerators& operator=(const DataGenerators&) = default;
228-
DataGenerators(DataGenerators&&) = default;
229-
DataGenerators& operator=(DataGenerators&&) = default;
230-
231-
// Create a ranged generator with the given limits.
232-
static auto Generator(Wide min, Wide max) { return Ranged(min, max); }
233-
234-
// Create a rangeless generator. Floating point types will leverage the
235-
// reinterpretation generator, which is recommended.
236-
static auto Generator() {
237-
if constexpr (std::is_floating_point_v<D>) {
238-
return Reinterpret();
239-
} else {
240-
return Ranged();
241-
}
242-
}
198+
template <typename D>
199+
using DefaultGenerator =
200+
SelectT<std::is_floating_point<D>,
201+
ReinterpretGenerator<D, std::uniform_real_distribution>,
202+
std::is_integral<D>,
203+
RangedGenerator<D, std::uniform_int_distribution>>;
243204

244-
// Initialize a random device with the proper types to work with generators.
245-
template <typename... Args>
246-
static auto Device(Args&&... args) {
247-
return RandomDevice(std::forward<Args>(args)...);
248-
}
205+
template <typename D>
206+
using DefaultRangedGenerator =
207+
SelectT<std::is_floating_point<D>,
208+
RangedGenerator<D, std::uniform_real_distribution>,
209+
std::is_integral<D>,
210+
RangedGenerator<D, std::uniform_int_distribution>>;
249211

250-
// Convenience method(s) to create a generator and device in a pair.
251-
static auto GeneratorAndDevice() {
252-
return std::make_pair(Generator(), Device());
253-
}
254-
static auto GeneratorAndDevice(int seed) {
255-
return std::make_pair(Generator(), Device(seed));
256-
}
257-
static auto GeneratorAndDevice(Wide min, Wide max) {
258-
return std::make_pair(Generator(min, max), Device());
259-
}
260-
static auto GeneratorAndDevice(int seed, Wide min, Wide max) {
261-
return std::make_pair(Generator(min, max), Device(seed));
262-
}
263-
};
212+
using DefaultDevice = RandomDevice<std::mt19937>;
264213

265214
} // namespace litert
266215

litert/cc/litert_rng_test.cc

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <cstddef>
2020
#include <cstdint>
2121
#include <random>
22+
#include <type_traits>
2223

2324
#include <gmock/gmock.h>
2425
#include <gtest/gtest.h>
@@ -74,10 +75,14 @@ TEST(LitertRngTestWithCustomRng, NoSeed) {
7475
HasSubstr("DummyRng(seed=<default>,"));
7576
}
7677

77-
using LiteRtRngTest = RngTest<>;
78+
using LiteRtRngTest = RngTest;
7879

7980
TEST_F(LiteRtRngTest, Ints) {
80-
auto [gen, device] = GeneratorAndDevice<int>();
81+
auto device = TracedDevice();
82+
auto gen = DefaultGenerator<int>();
83+
static_assert(
84+
std::is_same_v<decltype(gen),
85+
RangedGenerator<int, std::uniform_int_distribution>>);
8186
for (int i = 0; i < kTestIters; ++i) {
8287
const auto val = gen(device);
8388
ASSERT_LE(val, gen.Max());
@@ -88,7 +93,11 @@ TEST_F(LiteRtRngTest, Ints) {
8893
TEST_F(LiteRtRngTest, IntsWithRange) {
8994
static constexpr auto kMin = 10;
9095
static constexpr auto kMax = 20;
91-
auto [gen, device] = GeneratorAndDevice<int>(kMin, kMax);
96+
auto device = TracedDevice();
97+
auto gen = DefaultGenerator<int>(kMin, kMax);
98+
static_assert(
99+
std::is_same_v<decltype(gen),
100+
RangedGenerator<int, std::uniform_int_distribution>>);
92101
EXPECT_EQ(gen.Max(), kMax);
93102
EXPECT_EQ(gen.Min(), kMin);
94103
for (int i = 0; i < kTestIters; ++i) {
@@ -99,9 +108,13 @@ TEST_F(LiteRtRngTest, IntsWithRange) {
99108
}
100109

101110
TEST_F(LiteRtRngTest, FloatsWithRange) {
102-
static constexpr auto kMin = 10;
103-
static constexpr auto kMax = 20;
104-
auto [gen, device] = GeneratorAndDevice<float>(kMin, kMax);
111+
static constexpr auto kMin = 10.0f;
112+
static constexpr auto kMax = 20.0f;
113+
auto device = TracedDevice();
114+
auto gen = DefaultRangedGenerator<float>(kMin, kMax);
115+
static_assert(
116+
std::is_same_v<decltype(gen),
117+
RangedGenerator<float, std::uniform_real_distribution>>);
105118
EXPECT_EQ(gen.Max(), kMax);
106119
EXPECT_EQ(gen.Min(), kMin);
107120
for (int i = 0; i < kTestIters; ++i) {
@@ -112,7 +125,11 @@ TEST_F(LiteRtRngTest, FloatsWithRange) {
112125
}
113126

114127
TEST_F(LiteRtRngTest, ReinterpretFloat) {
115-
auto [gen, device] = GeneratorAndDevice<float>();
128+
auto device = TracedDevice();
129+
auto gen = DefaultGenerator<float>();
130+
static_assert(std::is_same_v<
131+
decltype(gen),
132+
ReinterpretGenerator<float, std::uniform_real_distribution>>);
116133
for (int i = 0; i < kTestIters; ++i) {
117134
const auto val = gen(device);
118135
ASSERT_FALSE(std::isnan(val));
@@ -123,7 +140,8 @@ TEST_F(LiteRtRngTest, ReinterpretFloat) {
123140
}
124141

125142
TEST_F(LiteRtRngTest, TestWithFuzz) {
126-
auto [gen, device] = GeneratorAndDevice<int>();
143+
auto device = TracedDevice();
144+
auto gen = DefaultGenerator<int>();
127145
for (auto _ :
128146
FuzzBlock(std::chrono::milliseconds(50), kTestIters, kTestIters)) {
129147
const auto val = gen(device);
@@ -133,5 +151,4 @@ TEST_F(LiteRtRngTest, TestWithFuzz) {
133151
}
134152

135153
} // namespace
136-
137154
} // namespace litert

litert/test/rng_fixture.h

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,7 @@ namespace litert::testing {
3232
// set from gtest, which can be configured from command line for
3333
// reproducibility. It can also be used to set up repeated blocks of code for
3434
// fuzzing.
35-
template <template <typename> typename Distribution = Uniform,
36-
typename Engine = DefaultEngine>
3735
class RngTest : public ::testing::Test {
38-
private:
39-
template <typename T>
40-
using Fact = DataGenerators<T, Distribution, Engine>;
41-
4236
public:
4337
void TearDown() override {
4438
for (const auto& block : fuzz_blocks_) {
@@ -49,19 +43,9 @@ class RngTest : public ::testing::Test {
4943
}
5044

5145
protected:
52-
template <typename T, typename... Args>
53-
auto Generator(Args&&... args) {
54-
return Fact<T>::Generator(std::forward<Args>(args)...);
55-
}
56-
template <typename T>
57-
auto Device() {
58-
return TraceSeedInfo(Fact<T>::Device(CurrentSeed()));
59-
}
60-
template <typename T, typename... Args>
61-
auto GeneratorAndDevice(Args&&... generator_args) {
62-
return std::make_pair(
63-
Fact<T>::Generator(std::forward<Args>(generator_args)...),
64-
TraceSeedInfo(Fact<T>::Device(CurrentSeed())));
46+
template <typename Device = DefaultDevice>
47+
auto TracedDevice() {
48+
return TraceSeedInfo(Device(CurrentSeed()));
6549
}
6650

6751
template <typename... Args>

0 commit comments

Comments
 (0)