Skip to content

Commit d2e83a9

Browse files
authored
add layout test (#13)
1 parent 43e5464 commit d2e83a9

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

test/LayoutTest.cpp

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
#include <torch/all.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
20+
class LayoutTest : public ::testing::Test {
21+
protected:
22+
void SetUp() override {
23+
std::vector<int64_t> shape = {2, 3, 4};
24+
tensor = at::ones(shape, at::kFloat);
25+
}
26+
27+
at::Tensor tensor;
28+
};
29+
30+
// 测试 layout
31+
TEST_F(LayoutTest, Layout) {
32+
auto file_name = g_custom_param.get();
33+
FileManerger file(file_name);
34+
file.createFile();
35+
36+
// 默认创建的张量应该是 strided 布局
37+
c10::Layout layout = tensor.layout();
38+
file << std::to_string(static_cast<int8_t>(layout)) << " ";
39+
file.saveFile();
40+
}
41+
42+
// 测试 layout 常量别名
43+
TEST_F(LayoutTest, LayoutConstants) {
44+
auto file_name = g_custom_param.get();
45+
FileManerger file(file_name);
46+
file.createFile();
47+
48+
// 测试 c10 命名空间下的常量别名
49+
file << std::to_string(c10::kStrided == c10::Layout::Strided) << " ";
50+
file << std::to_string(c10::kSparse == c10::Layout::Sparse) << " ";
51+
file << std::to_string(c10::kSparseCsr == c10::Layout::SparseCsr) << " ";
52+
file << std::to_string(c10::kSparseCsc == c10::Layout::SparseCsc) << " ";
53+
file << std::to_string(c10::kSparseBsr == c10::Layout::SparseBsr) << " ";
54+
file << std::to_string(c10::kSparseBsc == c10::Layout::SparseBsc) << " ";
55+
file << std::to_string(c10::kMkldnn == c10::Layout::Mkldnn) << " ";
56+
file << std::to_string(c10::kJagged == c10::Layout::Jagged) << " ";
57+
file.saveFile();
58+
}
59+
60+
// 测试 at 命名空间下的 layout 常量
61+
TEST_F(LayoutTest, LayoutConstantsInAtNamespace) {
62+
auto file_name = g_custom_param.get();
63+
FileManerger file(file_name);
64+
file.createFile();
65+
66+
file << std::to_string(at::kStrided == c10::Layout::Strided) << " ";
67+
file << std::to_string(at::kSparse == c10::Layout::Sparse) << " ";
68+
file << std::to_string(at::kSparseCsr == c10::Layout::SparseCsr) << " ";
69+
file << std::to_string(at::kSparseCsc == c10::Layout::SparseCsc) << " ";
70+
file << std::to_string(at::kSparseBsr == c10::Layout::SparseBsr) << " ";
71+
file << std::to_string(at::kSparseBsc == c10::Layout::SparseBsc) << " ";
72+
file << std::to_string(at::kMkldnn == c10::Layout::Mkldnn) << " ";
73+
file << std::to_string(at::kJagged == c10::Layout::Jagged) << " ";
74+
file.saveFile();
75+
}
76+
77+
// 测试 torch 命名空间下的 layout 常量
78+
TEST_F(LayoutTest, LayoutConstantsInTorchNamespace) {
79+
auto file_name = g_custom_param.get();
80+
FileManerger file(file_name);
81+
file.createFile();
82+
83+
file << std::to_string(torch::kStrided == c10::Layout::Strided) << " ";
84+
file << std::to_string(torch::kSparse == c10::Layout::Sparse) << " ";
85+
file << std::to_string(torch::kSparseCsr == c10::Layout::SparseCsr) << " ";
86+
file << std::to_string(torch::kSparseCsc == c10::Layout::SparseCsc) << " ";
87+
file << std::to_string(torch::kSparseBsr == c10::Layout::SparseBsr) << " ";
88+
file << std::to_string(torch::kSparseBsc == c10::Layout::SparseBsc) << " ";
89+
file << std::to_string(torch::kMkldnn == c10::Layout::Mkldnn) << " ";
90+
file << std::to_string(torch::kJagged == c10::Layout::Jagged) << " ";
91+
file.saveFile();
92+
}
93+
94+
// 测试 layout 枚举值
95+
TEST_F(LayoutTest, LayoutEnumValues) {
96+
auto file_name = g_custom_param.get();
97+
FileManerger file(file_name);
98+
file.createFile();
99+
100+
// 测试 Layout 枚举的底层值
101+
file << std::to_string(static_cast<int8_t>(c10::Layout::Strided)) << " ";
102+
file << std::to_string(static_cast<int8_t>(c10::Layout::Sparse)) << " ";
103+
file << std::to_string(static_cast<int8_t>(c10::Layout::SparseCsr)) << " ";
104+
file << std::to_string(static_cast<int8_t>(c10::Layout::Mkldnn)) << " ";
105+
file << std::to_string(static_cast<int8_t>(c10::Layout::SparseCsc)) << " ";
106+
file << std::to_string(static_cast<int8_t>(c10::Layout::SparseBsr)) << " ";
107+
file << std::to_string(static_cast<int8_t>(c10::Layout::SparseBsc)) << " ";
108+
file << std::to_string(static_cast<int8_t>(c10::Layout::Jagged)) << " ";
109+
file << std::to_string(static_cast<int8_t>(c10::Layout::NumOptions)) << " ";
110+
file.saveFile();
111+
}
112+
113+
// 测试 layout 输出流操作符
114+
TEST_F(LayoutTest, LayoutOutputStream) {
115+
auto file_name = g_custom_param.get();
116+
FileManerger file(file_name);
117+
file.createFile();
118+
119+
std::ostringstream oss;
120+
121+
oss.str("");
122+
oss << c10::Layout::Strided;
123+
file << oss.str() << " ";
124+
125+
oss.str("");
126+
oss << c10::Layout::Sparse;
127+
file << oss.str() << " ";
128+
129+
oss.str("");
130+
oss << c10::Layout::SparseCsr;
131+
file << oss.str() << " ";
132+
133+
oss.str("");
134+
oss << c10::Layout::SparseCsc;
135+
file << oss.str() << " ";
136+
137+
oss.str("");
138+
oss << c10::Layout::SparseBsr;
139+
file << oss.str() << " ";
140+
141+
oss.str("");
142+
oss << c10::Layout::SparseBsc;
143+
file << oss.str() << " ";
144+
145+
oss.str("");
146+
oss << c10::Layout::Mkldnn;
147+
file << oss.str() << " ";
148+
149+
oss.str("");
150+
oss << c10::Layout::Jagged;
151+
file << oss.str() << " ";
152+
153+
file.saveFile();
154+
}
155+
156+
// 测试使用 kStrided 常量与 tensor.layout() 比较
157+
TEST_F(LayoutTest, LayoutWithConstant) {
158+
auto file_name = g_custom_param.get();
159+
FileManerger file(file_name);
160+
file.createFile();
161+
162+
// 使用常量别名进行比较
163+
file << std::to_string(tensor.layout() == at::kStrided) << " ";
164+
file << std::to_string(tensor.layout() == torch::kStrided) << " ";
165+
file << std::to_string(tensor.layout() == c10::kStrided) << " ";
166+
167+
// 确保不是其他布局类型
168+
file << std::to_string(tensor.layout() != at::kSparse) << " ";
169+
file << std::to_string(tensor.layout() != at::kSparseCsr) << " ";
170+
file << std::to_string(tensor.layout() != at::kMkldnn) << " ";
171+
file.saveFile();
172+
}
173+
174+
} // namespace test
175+
} // namespace at

0 commit comments

Comments
 (0)