Skip to content

Commit 04de34d

Browse files
LukeBoyercopybara-github
authored andcommitted
Add overload to generate tensor types with random rank. Also set the max generated dimension to a reasonable value.
LiteRT-PiperOrigin-RevId: 775418960
1 parent fd9e47d commit 04de34d

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

litert/cc/litert_rng.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,16 @@ class RandomTensorType {
233233
using DimSpec = std::variant<DimSize, DimRange>;
234234
using ElementTypeSpec = std::vector<LiteRtElementType>;
235235

236-
static constexpr auto kMinDim = NumericLimits<DimSize>::Lowest();
237-
static constexpr auto kMaxDim = NumericLimits<DimSize>::Max();
236+
// Max value of a random dimension.
237+
static constexpr auto kMaxDim = 1024;
238238

239+
// Generate a random tensor type with a pre-determined rank given
240+
// in `shape_spec`.
239241
template <typename Rng>
240242
Expected<LiteRtRankedTensorType> Generate(
241-
Rng& rng,
243+
Rng& rng, const std::vector<std::optional<DimSpec>>& shape_spec,
242244
const ElementTypeSpec& type = {kLiteRtElementTypeInt32,
243-
kLiteRtElementTypeFloat32},
244-
const std::vector<std::optional<DimSpec>>& shape_spec = {}) {
245+
kLiteRtElementTypeFloat32}) {
245246
const auto rank = shape_spec.size();
246247
if (rank > LITERT_TENSOR_MAX_RANK) {
247248
return Error(kLiteRtStatusErrorInvalidArgument, "Rank too large");
@@ -255,12 +256,24 @@ class RandomTensorType {
255256
return res;
256257
}
257258

259+
// Generate a random tensor type with a random rank.
260+
template <typename Rng>
261+
Expected<LiteRtRankedTensorType> Generate(
262+
Rng& rng, size_t max_rank = LITERT_TENSOR_MAX_RANK,
263+
const ElementTypeSpec& type = {kLiteRtElementTypeInt32,
264+
kLiteRtElementTypeFloat32}) {
265+
DimGenerator rank_gen(0, max_rank);
266+
std::vector<std::optional<DimSpec>> shape_spec(rank_gen(rng), std::nullopt);
267+
return Generate(rng, shape_spec, type);
268+
}
269+
258270
private:
259271
template <typename Rng>
260272
DimSize GenerateDim(Rng& rng, const std::optional<DimSpec>& dim) {
261273
if (!dim) {
262-
DimGenerator gen;
263-
return gen(rng);
274+
DimGenerator gen(0, kMaxDim);
275+
const auto res = gen(rng);
276+
return res;
264277
} else if (std::holds_alternative<DimSize>(*dim)) {
265278
auto d = std::get<DimSize>(*dim);
266279
return d;

litert/cc/litert_rng_test.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ TEST_F(LiteRtRngTest, FullySpecifiedRandomTensorType) {
156156
auto device = TracedDevice();
157157
RandomTensorType type;
158158
auto tensor_type = type.Generate(
159-
device, {kLiteRtElementTypeFloat32},
160-
{RandomTensorType::DimSpec(2u), RandomTensorType::DimSpec(2u)});
159+
device, {RandomTensorType::DimSpec(2u), RandomTensorType::DimSpec(2u)},
160+
{kLiteRtElementTypeFloat32});
161161
ASSERT_TRUE(tensor_type);
162162
EXPECT_EQ(tensor_type->element_type, kLiteRtElementTypeFloat32);
163163
EXPECT_EQ(tensor_type->layout.dimensions[0], 2);
@@ -168,7 +168,7 @@ TEST_F(LiteRtRngTest, RandomElementType) {
168168
auto device = TracedDevice();
169169
RandomTensorType type;
170170
auto tensor_type = type.Generate(
171-
device, {kLiteRtElementTypeFloat32, kLiteRtElementTypeInt32});
171+
device, {}, {kLiteRtElementTypeFloat32, kLiteRtElementTypeInt32});
172172
ASSERT_TRUE(tensor_type);
173173
EXPECT_TRUE(tensor_type->element_type == kLiteRtElementTypeFloat32 ||
174174
tensor_type->element_type == kLiteRtElementTypeInt32);
@@ -178,8 +178,8 @@ TEST_F(LiteRtRngTest, RandomTensorShape) {
178178
auto device = TracedDevice();
179179
RandomTensorType type;
180180
auto tensor_type =
181-
type.Generate(device, {kLiteRtElementTypeFloat32},
182-
{RandomTensorType::DimRange(1u, 3u), std::nullopt});
181+
type.Generate(device, {RandomTensorType::DimRange(1u, 3u), std::nullopt},
182+
{kLiteRtElementTypeFloat32});
183183
ASSERT_TRUE(tensor_type);
184184
EXPECT_EQ(tensor_type->element_type, kLiteRtElementTypeFloat32);
185185
EXPECT_EQ(tensor_type->layout.rank, 2);
@@ -188,7 +188,15 @@ TEST_F(LiteRtRngTest, RandomTensorShape) {
188188
EXPECT_LE(dim1, 3u);
189189
const auto dim2 = tensor_type->layout.dimensions[1];
190190
EXPECT_GE(dim2, 0u);
191-
EXPECT_LE(dim2, NumericLimits<uint32_t>::Max());
191+
EXPECT_LE(dim2, RandomTensorType::kMaxDim);
192+
}
193+
194+
TEST_F(LiteRtRngTest, RandomTensorShapeWithRandomRank) {
195+
auto device = TracedDevice();
196+
RandomTensorType type;
197+
auto tensor_type = type.Generate(device, /*max_rank=*/4);
198+
ASSERT_TRUE(tensor_type);
199+
EXPECT_LE(tensor_type->layout.rank, 4);
192200
}
193201

194202
} // namespace

0 commit comments

Comments
 (0)