@@ -116,14 +116,11 @@ class RandomDevice {
116116
117117// Abstract base class for generating data of a certain type from a given rng
118118// device, e.g. populating tensors and the like.
119- template <typename D, template <typename > typename Dist, typename DeviceBase >
119+ template <typename D, template <typename > typename Dist>
120120class DataGenerator {
121121 public:
122122 using DataType = D;
123123 using Wide = WideType<D>;
124- using Device = RandomDevice<DeviceBase>;
125-
126- virtual DataType operator ()(Device& rng) = 0;
127124
128125 // Bounds of distribution.
129126 DataType Max () const { return dist_.max (); }
@@ -134,14 +131,13 @@ class DataGenerator {
134131};
135132
136133// A data generator that generates data within a given range.
137- template <typename D, template <typename > typename Dist, typename DeviceBase >
138- class RangedGenerator final : public DataGenerator<D, Dist, DeviceBase > {
134+ template <typename D, template <typename > typename Dist>
135+ class RangedGenerator final : public DataGenerator<D, Dist> {
139136 private:
140- using Base = DataGenerator<D, Dist, DeviceBase >;
137+ using Base = DataGenerator<D, Dist>;
141138
142139 public:
143140 using typename Base::DataType;
144- using typename Base::Device;
145141 using typename Base::Wide;
146142
147143 RangedGenerator () = default ;
@@ -155,30 +151,32 @@ class RangedGenerator final : public DataGenerator<D, Dist, DeviceBase> {
155151 RangedGenerator (RangedGenerator&&) = default ;
156152 RangedGenerator& operator =(RangedGenerator&&) = default ;
157153
158- DataType operator ()(Device& rng) override { return this ->dist_ (rng); }
154+ template <typename Rng>
155+ DataType operator ()(Rng& rng) {
156+ return this ->dist_ (rng);
157+ }
159158};
160159
161160// A rangeless float generator that casts random bits to the given float type.
162161// This generally produces higher quality floats more repersentative of the
163162// target distribution than a ranged generator. Particularly covers more values
164163// around zero and infinities.
165- template <typename D, template <typename > typename Dist, typename DeviceBase,
166- typename Enable = void >
167- class ReinterpretGenerator final : public DataGenerator<D, Dist, DeviceBase> {};
164+ template <typename D, template <typename > typename Dist, typename Enable = void >
165+ class ReinterpretGenerator final : public DataGenerator<D, Dist> {};
168166
169- template <typename D, template <typename > typename Dist, typename DeviceBase >
170- class ReinterpretGenerator <D, Dist, DeviceBase,
167+ template <typename D, template <typename > typename Dist>
168+ class ReinterpretGenerator <D, Dist,
171169 std::enable_if_t <std::is_floating_point_v<D>>>
172- final : public DataGenerator<D, Dist, DeviceBase > {
170+ final : public DataGenerator<D, Dist> {
173171 private:
174- using Base = DataGenerator<D, Dist, DeviceBase >;
172+ using Base = DataGenerator<D, Dist>;
175173
176174 public:
177175 using typename Base::DataType;
178- using typename Base::Device;
179176 using typename Base::Wide;
180177
181- DataType operator ()(Device& rng) override {
178+ template <typename Rng>
179+ DataType operator ()(Rng& rng) {
182180 DataType res;
183181 auto bits = rng ();
184182 memcpy (&res, &bits, sizeof (res));
@@ -195,72 +193,23 @@ class ReinterpretGenerator<D, Dist, DeviceBase,
195193 ReinterpretGenerator& operator =(ReinterpretGenerator&&) = default ;
196194};
197195
198- // Recommended distribution for data generators.
199- template <typename T>
200- using Uniform =
201- SelectT<std::is_floating_point<T>, std::uniform_real_distribution<T>,
202- std::is_integral<T>, std::uniform_int_distribution<T>>;
203-
204- // Recommended engine for data generators.
205- using DefaultEngine = std::mt19937_64;
206-
207- // Factory for creating data generators from just a data type with recommended
208- // defaults.
209- template <typename D, template <typename > typename Distribution = Uniform,
210- typename Engine = DefaultEngine>
211- class DataGenerators {
212- // Exotic types not yet supported (e.g. quant, complex, half-precision etc).
213- static_assert (std::is_floating_point_v<D> || std::is_integral_v<D>);
196+ // DEFAULTS FOR DATA GENERATORS ////////////////////////////////////////////////
214197
215- private:
216- using GeneratorBase = DataGenerator<D, Distribution, Engine>;
217-
218- public:
219- using Reinterpret = ReinterpretGenerator<D, Uniform, Engine>;
220- using Ranged = RangedGenerator<D, Uniform, Engine>;
221- using Dataype = GeneratorBase::DataType;
222- using Wide = GeneratorBase::Wide;
223- using RandomDevice = GeneratorBase::Device;
224-
225- DataGenerators () = default ;
226- DataGenerators (const DataGenerators&) = default ;
227- DataGenerators& operator =(const DataGenerators&) = default ;
228- DataGenerators (DataGenerators&&) = default ;
229- DataGenerators& operator =(DataGenerators&&) = default ;
230-
231- // Create a ranged generator with the given limits.
232- static auto Generator (Wide min, Wide max) { return Ranged (min, max); }
233-
234- // Create a rangeless generator. Floating point types will leverage the
235- // reinterpretation generator, which is recommended.
236- static auto Generator () {
237- if constexpr (std::is_floating_point_v<D>) {
238- return Reinterpret ();
239- } else {
240- return Ranged ();
241- }
242- }
198+ template <typename D>
199+ using DefaultGenerator =
200+ SelectT<std::is_floating_point<D>,
201+ ReinterpretGenerator<D, std::uniform_real_distribution>,
202+ std::is_integral<D>,
203+ RangedGenerator<D, std::uniform_int_distribution>>;
243204
244- // Initialize a random device with the proper types to work with generators.
245- template <typename ... Args>
246- static auto Device (Args&&... args) {
247- return RandomDevice (std::forward<Args>(args)...);
248- }
205+ template <typename D>
206+ using DefaultRangedGenerator =
207+ SelectT<std::is_floating_point<D>,
208+ RangedGenerator<D, std::uniform_real_distribution>,
209+ std::is_integral<D>,
210+ RangedGenerator<D, std::uniform_int_distribution>>;
249211
250- // Convenience method(s) to create a generator and device in a pair.
251- static auto GeneratorAndDevice () {
252- return std::make_pair (Generator (), Device ());
253- }
254- static auto GeneratorAndDevice (int seed) {
255- return std::make_pair (Generator (), Device (seed));
256- }
257- static auto GeneratorAndDevice (Wide min, Wide max) {
258- return std::make_pair (Generator (min, max), Device ());
259- }
260- static auto GeneratorAndDevice (int seed, Wide min, Wide max) {
261- return std::make_pair (Generator (min, max), Device (seed));
262- }
263- };
212+ using DefaultDevice = RandomDevice<std::mt19937>;
264213
265214} // namespace litert
266215
0 commit comments