Skip to content

Commit f3d37fd

Browse files
authored
add TensorAccessor related tests (#25)
1 parent d19ebe4 commit f3d37fd

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

test/TensorAccessorTest.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
#include <torch/all.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
20+
class TensorAccessorTest : public ::testing::Test {
21+
protected:
22+
void SetUp() override {
23+
std::vector<int64_t> shape = {2, 3, 4};
24+
tensor = at::ones(shape, at::kFloat);
25+
}
26+
27+
at::Tensor tensor;
28+
};
29+
30+
// 测试 packed_accessor32
31+
TEST_F(TensorAccessorTest, PackedAccessor32) {
32+
auto file_name = g_custom_param.get();
33+
FileManerger file(file_name);
34+
file.createFile();
35+
auto accessor = tensor.packed_accessor32<float, 3>();
36+
file << std::to_string(accessor.size(0)) << " ";
37+
file << std::to_string(accessor.size(1)) << " ";
38+
file << std::to_string(accessor.size(2)) << " ";
39+
file << std::to_string(accessor[0][0][0]) << " ";
40+
file << std::to_string(accessor[1][2][3]) << " ";
41+
file.saveFile();
42+
}
43+
44+
// 测试 packed_accessor64
45+
TEST_F(TensorAccessorTest, PackedAccessor64) {
46+
auto file_name = g_custom_param.get();
47+
FileManerger file(file_name);
48+
file.createFile();
49+
auto accessor = tensor.packed_accessor64<float, 3>();
50+
file << std::to_string(accessor.size(0)) << " ";
51+
file << std::to_string(accessor.size(1)) << " ";
52+
file << std::to_string(accessor.size(2)) << " ";
53+
file << std::to_string(accessor[0][0][0]) << " ";
54+
file << std::to_string(accessor[1][2][3]) << " ";
55+
file.saveFile();
56+
}
57+
58+
// 测试 generic_packed_accessor
59+
TEST_F(TensorAccessorTest, GenericPackedAccessor) {
60+
auto file_name = g_custom_param.get();
61+
FileManerger file(file_name);
62+
file.createFile();
63+
auto accessor = tensor.generic_packed_accessor<float, 3>();
64+
file << std::to_string(accessor.size(0)) << " ";
65+
file << std::to_string(accessor.size(1)) << " ";
66+
file << std::to_string(accessor.size(2)) << " ";
67+
file << std::to_string(accessor[0][0][0]) << " ";
68+
file << std::to_string(accessor[1][2][3]) << " ";
69+
file.saveFile();
70+
}
71+
72+
// 测试 is_non_overlapping_and_dense
73+
TEST_F(TensorAccessorTest, IsNonOverlappingAndDense) {
74+
auto file_name = g_custom_param.get();
75+
FileManerger file(file_name);
76+
file.createFile();
77+
file << std::to_string(tensor.is_non_overlapping_and_dense()) << " ";
78+
79+
// 测试非连续的tensor
80+
at::Tensor transposed = tensor.transpose(0, 2);
81+
file << std::to_string(transposed.is_non_overlapping_and_dense()) << " ";
82+
83+
// 测试连续化后的tensor
84+
at::Tensor contiguous = transposed.contiguous();
85+
file << std::to_string(contiguous.is_non_overlapping_and_dense()) << " ";
86+
file.saveFile();
87+
}
88+
89+
// 测试 has_names
90+
TEST_F(TensorAccessorTest, HasNames) {
91+
auto file_name = g_custom_param.get();
92+
FileManerger file(file_name);
93+
file.createFile();
94+
file << std::to_string(tensor.has_names()) << " ";
95+
file.saveFile();
96+
}
97+
98+
} // namespace test
99+
} // namespace at

0 commit comments

Comments
 (0)