Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,79 @@ TEST_F(TensorTest, Transpose) {
EXPECT_EQ(transposed.sizes()[2], 2);
}

// 测试 sym_size
TEST_F(TensorTest, SymSize) {
// 获取符号化的单个维度大小
c10::SymInt sym_size_0 = tensor.sym_size(0);
c10::SymInt sym_size_1 = tensor.sym_size(1);
c10::SymInt sym_size_2 = tensor.sym_size(2);

// 验证符号化大小与实际大小一致
EXPECT_EQ(sym_size_0, 2);
EXPECT_EQ(sym_size_1, 3);
EXPECT_EQ(sym_size_2, 4);

// 测试负索引
c10::SymInt sym_size_neg1 = tensor.sym_size(-1);
EXPECT_EQ(sym_size_neg1, 4);
}

// 测试 sym_stride
TEST_F(TensorTest, SymStride) {
// 获取符号化的单个维度步长
c10::SymInt sym_stride_0 = tensor.sym_stride(0);
c10::SymInt sym_stride_1 = tensor.sym_stride(1);
c10::SymInt sym_stride_2 = tensor.sym_stride(2);

// 验证符号化步长
EXPECT_GT(sym_stride_0, 0);
EXPECT_GT(sym_stride_1, 0);
EXPECT_GT(sym_stride_2, 0);

// 测试负索引
c10::SymInt sym_stride_neg1 = tensor.sym_stride(-1);
EXPECT_EQ(sym_stride_neg1, 1); // 最后一维步长通常为1
}

// 测试 sym_sizes
TEST_F(TensorTest, SymSizes) {
// 获取符号化的所有维度大小
c10::SymIntArrayRef sym_sizes = tensor.sym_sizes();

// 验证维度数量
EXPECT_EQ(sym_sizes.size(), 3U);

// 验证每个维度的大小
EXPECT_EQ(sym_sizes[0], 2);
EXPECT_EQ(sym_sizes[1], 3);
EXPECT_EQ(sym_sizes[2], 4);
}

// 测试 sym_strides
TEST_F(TensorTest, SymStrides) {
// 获取符号化的所有维度步长
c10::SymIntArrayRef sym_strides = tensor.sym_strides();

// 验证维度数量
EXPECT_EQ(sym_strides.size(), 3U);

// 验证步长值都大于0
for (size_t i = 0; i < sym_strides.size(); ++i) {
EXPECT_GT(sym_strides[i], 0);
}
}

// 测试 sym_numel
TEST_F(TensorTest, SymNumel) {
// 获取符号化的元素总数
c10::SymInt sym_numel = tensor.sym_numel();

// 验证符号化元素数与实际元素数一致
EXPECT_EQ(sym_numel, 24); // 2*3*4

// 验证与 numel() 结果一致
EXPECT_EQ(sym_numel, tensor.numel());
}

} // namespace test
} // namespace at