@@ -156,8 +156,8 @@ TEST_F(LiteRtRngTest, FullySpecifiedRandomTensorType) {
156156 auto device = TracedDevice ();
157157 RandomTensorType type;
158158 auto tensor_type = type.Generate (
159- device, {kLiteRtElementTypeFloat32 },
160- {RandomTensorType::DimSpec ( 2u ), RandomTensorType::DimSpec ( 2u ) });
159+ device, {RandomTensorType::DimSpec ( 2u ), RandomTensorType::DimSpec ( 2u ) },
160+ {kLiteRtElementTypeFloat32 });
161161 ASSERT_TRUE (tensor_type);
162162 EXPECT_EQ (tensor_type->element_type , kLiteRtElementTypeFloat32 );
163163 EXPECT_EQ (tensor_type->layout .dimensions [0 ], 2 );
@@ -168,7 +168,7 @@ TEST_F(LiteRtRngTest, RandomElementType) {
168168 auto device = TracedDevice ();
169169 RandomTensorType type;
170170 auto tensor_type = type.Generate (
171- device, {kLiteRtElementTypeFloat32 , kLiteRtElementTypeInt32 });
171+ device, {}, { kLiteRtElementTypeFloat32 , kLiteRtElementTypeInt32 });
172172 ASSERT_TRUE (tensor_type);
173173 EXPECT_TRUE (tensor_type->element_type == kLiteRtElementTypeFloat32 ||
174174 tensor_type->element_type == kLiteRtElementTypeInt32 );
@@ -178,8 +178,8 @@ TEST_F(LiteRtRngTest, RandomTensorShape) {
178178 auto device = TracedDevice ();
179179 RandomTensorType type;
180180 auto tensor_type =
181- type.Generate (device, {kLiteRtElementTypeFloat32 },
182- {RandomTensorType::DimRange ( 1u , 3u ), std:: nullopt });
181+ type.Generate (device, {RandomTensorType::DimRange ( 1u , 3u ), std:: nullopt },
182+ {kLiteRtElementTypeFloat32 });
183183 ASSERT_TRUE (tensor_type);
184184 EXPECT_EQ (tensor_type->element_type , kLiteRtElementTypeFloat32 );
185185 EXPECT_EQ (tensor_type->layout .rank , 2 );
@@ -188,7 +188,15 @@ TEST_F(LiteRtRngTest, RandomTensorShape) {
188188 EXPECT_LE (dim1, 3u );
189189 const auto dim2 = tensor_type->layout .dimensions [1 ];
190190 EXPECT_GE (dim2, 0u );
191- EXPECT_LE (dim2, NumericLimits<uint32_t >::Max ());
191+ EXPECT_LE (dim2, RandomTensorType::kMaxDim );
192+ }
193+
194+ TEST_F (LiteRtRngTest, RandomTensorShapeWithRandomRank) {
195+ auto device = TracedDevice ();
196+ RandomTensorType type;
197+ auto tensor_type = type.Generate (device, /* max_rank=*/ 4 );
198+ ASSERT_TRUE (tensor_type);
199+ EXPECT_LE (tensor_type->layout .rank , 4 );
192200}
193201
194202} // namespace
0 commit comments