Skip to content

Commit a96314b

Browse files
authored
add SymInt related tests (#18)
1 parent 6b1e21e commit a96314b

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

test/TensorTest.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,5 +213,79 @@ TEST_F(TensorTest, Transpose) {
213213
file.saveFile();
214214
}
215215

216+
// 测试 sym_size
217+
TEST_F(TensorTest, SymSize) {
218+
// 获取符号化的单个维度大小
219+
c10::SymInt sym_size_0 = tensor.sym_size(0);
220+
c10::SymInt sym_size_1 = tensor.sym_size(1);
221+
c10::SymInt sym_size_2 = tensor.sym_size(2);
222+
223+
// 验证符号化大小与实际大小一致
224+
EXPECT_EQ(sym_size_0, 2);
225+
EXPECT_EQ(sym_size_1, 3);
226+
EXPECT_EQ(sym_size_2, 4);
227+
228+
// 测试负索引
229+
c10::SymInt sym_size_neg1 = tensor.sym_size(-1);
230+
EXPECT_EQ(sym_size_neg1, 4);
231+
}
232+
233+
// 测试 sym_stride
234+
TEST_F(TensorTest, SymStride) {
235+
// 获取符号化的单个维度步长
236+
c10::SymInt sym_stride_0 = tensor.sym_stride(0);
237+
c10::SymInt sym_stride_1 = tensor.sym_stride(1);
238+
c10::SymInt sym_stride_2 = tensor.sym_stride(2);
239+
240+
// 验证符号化步长
241+
EXPECT_GT(sym_stride_0, 0);
242+
EXPECT_GT(sym_stride_1, 0);
243+
EXPECT_GT(sym_stride_2, 0);
244+
245+
// 测试负索引
246+
c10::SymInt sym_stride_neg1 = tensor.sym_stride(-1);
247+
EXPECT_EQ(sym_stride_neg1, 1); // 最后一维步长通常为1
248+
}
249+
250+
// 测试 sym_sizes
251+
TEST_F(TensorTest, SymSizes) {
252+
// 获取符号化的所有维度大小
253+
c10::SymIntArrayRef sym_sizes = tensor.sym_sizes();
254+
255+
// 验证维度数量
256+
EXPECT_EQ(sym_sizes.size(), 3U);
257+
258+
// 验证每个维度的大小
259+
EXPECT_EQ(sym_sizes[0], 2);
260+
EXPECT_EQ(sym_sizes[1], 3);
261+
EXPECT_EQ(sym_sizes[2], 4);
262+
}
263+
264+
// 测试 sym_strides
265+
TEST_F(TensorTest, SymStrides) {
266+
// 获取符号化的所有维度步长
267+
c10::SymIntArrayRef sym_strides = tensor.sym_strides();
268+
269+
// 验证维度数量
270+
EXPECT_EQ(sym_strides.size(), 3U);
271+
272+
// 验证步长值都大于0
273+
for (size_t i = 0; i < sym_strides.size(); ++i) {
274+
EXPECT_GT(sym_strides[i], 0);
275+
}
276+
}
277+
278+
// 测试 sym_numel
279+
TEST_F(TensorTest, SymNumel) {
280+
// 获取符号化的元素总数
281+
c10::SymInt sym_numel = tensor.sym_numel();
282+
283+
// 验证符号化元素数与实际元素数一致
284+
EXPECT_EQ(sym_numel, 24); // 2*3*4
285+
286+
// 验证与 numel() 结果一致
287+
EXPECT_EQ(sym_numel, tensor.numel());
288+
}
289+
216290
} // namespace test
217291
} // namespace at

0 commit comments

Comments
 (0)