|
20 | 20 | #include <ostream> |
21 | 21 | #include <random> |
22 | 22 | #include <type_traits> |
| 23 | +#include <variant> |
| 24 | +#include <vector> |
23 | 25 |
|
24 | 26 | #include "absl/strings/str_cat.h" // from @com_google_absl |
25 | 27 | #include "absl/strings/str_format.h" // from @com_google_absl |
26 | 28 | #include "absl/strings/string_view.h" // from @com_google_absl |
| 29 | +#include "litert/c/litert_common.h" |
| 30 | +#include "litert/c/litert_layout.h" |
| 31 | +#include "litert/c/litert_model.h" |
27 | 32 | #include "litert/cc/litert_detail.h" |
| 33 | +#include "litert/cc/litert_expected.h" |
28 | 34 | #include "litert/cc/litert_numerics.h" |
29 | 35 |
|
30 | 36 | // Various utilities and types for random number generation. |
@@ -112,7 +118,7 @@ class RandomDevice { |
112 | 118 | static constexpr ResultType max() { return Max(); } |
113 | 119 | }; |
114 | 120 |
|
115 | | -// TENSOR DATA GENERATOR /////////////////////////////////////////////////////// |
| 121 | +// PRIMITIVE DATA GENERATORS /////////////////////////////////////////////////// |
116 | 122 |
|
117 | 123 | // Abstract base class for generating data of a certain type from a given rng |
118 | 124 | // device, e.g. populating tensors and the like. |
@@ -211,6 +217,70 @@ using DefaultRangedGenerator = |
211 | 217 |
|
212 | 218 | using DefaultDevice = RandomDevice<std::mt19937>; |
213 | 219 |
|
| 220 | +// RANDOM TENSOR TYPES ///////////////////////////////////////////////////////// |
| 221 | + |
| 222 | +// This class composes the primitive data generators above to support |
| 223 | +// generating randomized tensor types (and shapes). |
| 224 | +class RandomTensorType { |
| 225 | + private: |
| 226 | + using DimSize = uint32_t; |
| 227 | + using DimGenerator = DefaultRangedGenerator<DimSize>; |
| 228 | + using ElementTypeInt = uint8_t; |
| 229 | + using ElementTypeGenerator = DefaultRangedGenerator<ElementTypeInt>; |
| 230 | + |
| 231 | + public: |
| 232 | + using DimRange = std::pair<DimSize, DimSize>; |
| 233 | + using DimSpec = std::variant<DimSize, DimRange>; |
| 234 | + using ElementTypeSpec = std::vector<LiteRtElementType>; |
| 235 | + |
| 236 | + static constexpr auto kMinDim = NumericLimits<DimSize>::Lowest(); |
| 237 | + static constexpr auto kMaxDim = NumericLimits<DimSize>::Max(); |
| 238 | + |
| 239 | + template <typename Rng> |
| 240 | + Expected<LiteRtRankedTensorType> Generate( |
| 241 | + Rng& rng, |
| 242 | + const ElementTypeSpec& type = {kLiteRtElementTypeInt32, |
| 243 | + kLiteRtElementTypeFloat32}, |
| 244 | + const std::vector<std::optional<DimSpec>>& shape_spec = {}) { |
| 245 | + const auto rank = shape_spec.size(); |
| 246 | + if (rank > LITERT_TENSOR_MAX_RANK) { |
| 247 | + return Error(kLiteRtStatusErrorInvalidArgument, "Rank too large"); |
| 248 | + } |
| 249 | + LiteRtRankedTensorType res; |
| 250 | + res.layout.rank = rank; |
| 251 | + res.element_type = GenerateElementType(rng, type); |
| 252 | + for (auto i = 0; i < rank; ++i) { |
| 253 | + res.layout.dimensions[i] = GenerateDim(rng, shape_spec[i]); |
| 254 | + } |
| 255 | + return res; |
| 256 | + } |
| 257 | + |
| 258 | + private: |
| 259 | + template <typename Rng> |
| 260 | + DimSize GenerateDim(Rng& rng, const std::optional<DimSpec>& dim) { |
| 261 | + if (!dim) { |
| 262 | + DimGenerator gen; |
| 263 | + return gen(rng); |
| 264 | + } else if (std::holds_alternative<DimSize>(*dim)) { |
| 265 | + auto d = std::get<DimSize>(*dim); |
| 266 | + return d; |
| 267 | + } else { |
| 268 | + auto d = std::get<DimRange>(*dim); |
| 269 | + DimGenerator gen(d.first, d.second); |
| 270 | + return gen(rng); |
| 271 | + } |
| 272 | + } |
| 273 | + |
| 274 | + template <typename Rng> |
| 275 | + LiteRtElementType GenerateElementType(Rng& rng, const ElementTypeSpec& type) { |
| 276 | + if (type.size() == 1) { |
| 277 | + return type.front(); |
| 278 | + } |
| 279 | + ElementTypeGenerator gen(0, type.size() - 1); |
| 280 | + return type[gen(rng)]; |
| 281 | + } |
| 282 | +}; |
| 283 | + |
214 | 284 | } // namespace litert |
215 | 285 |
|
216 | 286 | #endif // THIRD_PARTY_ODML_LITERT_LITERT_CC_LITERT_RNG_H_ |
0 commit comments