Skip to content

Commit bfb1db7

Browse files
LukeBoyercopybara-github
authored andcommitted
Add random tensor type generator.
LiteRT-PiperOrigin-RevId: 775336901
1 parent 2cc853c commit bfb1db7

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed

litert/cc/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,11 @@ cc_library(
597597
],
598598
deps = [
599599
":litert_detail",
600+
":litert_expected",
600601
":litert_numerics",
602+
"//litert/c:litert_common",
603+
"//litert/c:litert_layout",
604+
"//litert/c:litert_model",
601605
"@com_google_absl//absl/strings",
602606
"@com_google_absl//absl/strings:str_format",
603607
"@com_google_absl//absl/strings:string_view",
@@ -610,6 +614,7 @@ cc_test(
610614
deps = [
611615
":litert_numerics",
612616
":litert_rng",
617+
"//litert/c:litert_model_types",
613618
"//litert/test:rng_fixture",
614619
"@com_google_absl//absl/strings:str_format",
615620
"@com_google_absl//absl/strings:string_view",

litert/cc/litert_rng.h

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@
2020
#include <ostream>
2121
#include <random>
2222
#include <type_traits>
23+
#include <variant>
24+
#include <vector>
2325

2426
#include "absl/strings/str_cat.h" // from @com_google_absl
2527
#include "absl/strings/str_format.h" // from @com_google_absl
2628
#include "absl/strings/string_view.h" // from @com_google_absl
29+
#include "litert/c/litert_common.h"
30+
#include "litert/c/litert_layout.h"
31+
#include "litert/c/litert_model.h"
2732
#include "litert/cc/litert_detail.h"
33+
#include "litert/cc/litert_expected.h"
2834
#include "litert/cc/litert_numerics.h"
2935

3036
// Various utilities and types for random number generation.
@@ -112,7 +118,7 @@ class RandomDevice {
112118
static constexpr ResultType max() { return Max(); }
113119
};
114120

115-
// TENSOR DATA GENERATOR ///////////////////////////////////////////////////////
121+
// PRIMITIVE DATA GENERATORS ///////////////////////////////////////////////////
116122

117123
// Abstract base class for generating data of a certain type from a given rng
118124
// device, e.g. populating tensors and the like.
@@ -211,6 +217,70 @@ using DefaultRangedGenerator =
211217

212218
using DefaultDevice = RandomDevice<std::mt19937>;
213219

220+
// RANDOM TENSOR TYPES /////////////////////////////////////////////////////////
221+
222+
// This class composes the primitive data generators above to support
223+
// generating randomized tensor types (and shapes).
224+
class RandomTensorType {
225+
private:
226+
using DimSize = uint32_t;
227+
using DimGenerator = DefaultRangedGenerator<DimSize>;
228+
using ElementTypeInt = uint8_t;
229+
using ElementTypeGenerator = DefaultRangedGenerator<ElementTypeInt>;
230+
231+
public:
232+
using DimRange = std::pair<DimSize, DimSize>;
233+
using DimSpec = std::variant<DimSize, DimRange>;
234+
using ElementTypeSpec = std::vector<LiteRtElementType>;
235+
236+
static constexpr auto kMinDim = NumericLimits<DimSize>::Lowest();
237+
static constexpr auto kMaxDim = NumericLimits<DimSize>::Max();
238+
239+
template <typename Rng>
240+
Expected<LiteRtRankedTensorType> Generate(
241+
Rng& rng,
242+
const ElementTypeSpec& type = {kLiteRtElementTypeInt32,
243+
kLiteRtElementTypeFloat32},
244+
const std::vector<std::optional<DimSpec>>& shape_spec = {}) {
245+
const auto rank = shape_spec.size();
246+
if (rank > LITERT_TENSOR_MAX_RANK) {
247+
return Error(kLiteRtStatusErrorInvalidArgument, "Rank too large");
248+
}
249+
LiteRtRankedTensorType res;
250+
res.layout.rank = rank;
251+
res.element_type = GenerateElementType(rng, type);
252+
for (auto i = 0; i < rank; ++i) {
253+
res.layout.dimensions[i] = GenerateDim(rng, shape_spec[i]);
254+
}
255+
return res;
256+
}
257+
258+
private:
259+
template <typename Rng>
260+
DimSize GenerateDim(Rng& rng, const std::optional<DimSpec>& dim) {
261+
if (!dim) {
262+
DimGenerator gen;
263+
return gen(rng);
264+
} else if (std::holds_alternative<DimSize>(*dim)) {
265+
auto d = std::get<DimSize>(*dim);
266+
return d;
267+
} else {
268+
auto d = std::get<DimRange>(*dim);
269+
DimGenerator gen(d.first, d.second);
270+
return gen(rng);
271+
}
272+
}
273+
274+
template <typename Rng>
275+
LiteRtElementType GenerateElementType(Rng& rng, const ElementTypeSpec& type) {
276+
if (type.size() == 1) {
277+
return type.front();
278+
}
279+
ElementTypeGenerator gen(0, type.size() - 1);
280+
return type[gen(rng)];
281+
}
282+
};
283+
214284
} // namespace litert
215285

216286
#endif // THIRD_PARTY_ODML_LITERT_LITERT_CC_LITERT_RNG_H_

litert/cc/litert_rng_test.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
#include <cmath>
1919
#include <cstddef>
2020
#include <cstdint>
21+
#include <optional>
2122
#include <random>
2223
#include <type_traits>
2324

2425
#include <gmock/gmock.h>
2526
#include <gtest/gtest.h>
2627
#include "absl/strings/str_format.h" // from @com_google_absl
2728
#include "absl/strings/string_view.h" // from @com_google_absl
29+
#include "litert/c/litert_model.h"
2830
#include "litert/cc/litert_numerics.h"
2931
#include "litert/test/rng_fixture.h"
3032

@@ -150,5 +152,44 @@ TEST_F(LiteRtRngTest, TestWithFuzz) {
150152
}
151153
}
152154

155+
TEST_F(LiteRtRngTest, FullySpecifiedRandomTensorType) {
156+
auto device = TracedDevice();
157+
RandomTensorType type;
158+
auto tensor_type = type.Generate(
159+
device, {kLiteRtElementTypeFloat32},
160+
{RandomTensorType::DimSpec(2u), RandomTensorType::DimSpec(2u)});
161+
ASSERT_TRUE(tensor_type);
162+
EXPECT_EQ(tensor_type->element_type, kLiteRtElementTypeFloat32);
163+
EXPECT_EQ(tensor_type->layout.dimensions[0], 2);
164+
EXPECT_EQ(tensor_type->layout.dimensions[1], 2);
165+
}
166+
167+
TEST_F(LiteRtRngTest, RandomElementType) {
168+
auto device = TracedDevice();
169+
RandomTensorType type;
170+
auto tensor_type = type.Generate(
171+
device, {kLiteRtElementTypeFloat32, kLiteRtElementTypeInt32});
172+
ASSERT_TRUE(tensor_type);
173+
EXPECT_TRUE(tensor_type->element_type == kLiteRtElementTypeFloat32 ||
174+
tensor_type->element_type == kLiteRtElementTypeInt32);
175+
}
176+
177+
TEST_F(LiteRtRngTest, RandomTensorShape) {
178+
auto device = TracedDevice();
179+
RandomTensorType type;
180+
auto tensor_type =
181+
type.Generate(device, {kLiteRtElementTypeFloat32},
182+
{RandomTensorType::DimRange(1u, 3u), std::nullopt});
183+
ASSERT_TRUE(tensor_type);
184+
EXPECT_EQ(tensor_type->element_type, kLiteRtElementTypeFloat32);
185+
EXPECT_EQ(tensor_type->layout.rank, 2);
186+
const auto dim1 = tensor_type->layout.dimensions[0];
187+
EXPECT_GE(dim1, 1u);
188+
EXPECT_LE(dim1, 3u);
189+
const auto dim2 = tensor_type->layout.dimensions[1];
190+
EXPECT_GE(dim2, 0u);
191+
EXPECT_LE(dim2, NumericLimits<uint32_t>::Max());
192+
}
193+
153194
} // namespace
154195
} // namespace litert

0 commit comments

Comments
 (0)