Skip to content

Commit 5719c07

Browse files
committed
add layout test
1 parent 9d35161 commit 5719c07

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

test/TensorTest.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <gtest/gtest.h>
55
#include <torch/all.h>
66

7+
#include <sstream>
78
#include <vector>
89

910
namespace at {
@@ -170,5 +171,113 @@ TEST_F(TensorTest, Transpose) {
170171
EXPECT_EQ(transposed.sizes()[2], 2);
171172
}
172173

174+
// 测试 layout
175+
TEST_F(TensorTest, Layout) {
176+
// 默认创建的张量应该是 strided 布局
177+
c10::Layout layout = tensor.layout();
178+
EXPECT_EQ(layout, c10::Layout::Strided);
179+
}
180+
181+
// 测试 layout 常量别名
182+
TEST_F(TensorTest, LayoutConstants) {
183+
// 测试 c10 命名空间下的常量别名
184+
EXPECT_EQ(c10::kStrided, c10::Layout::Strided);
185+
EXPECT_EQ(c10::kSparse, c10::Layout::Sparse);
186+
EXPECT_EQ(c10::kSparseCsr, c10::Layout::SparseCsr);
187+
EXPECT_EQ(c10::kSparseCsc, c10::Layout::SparseCsc);
188+
EXPECT_EQ(c10::kSparseBsr, c10::Layout::SparseBsr);
189+
EXPECT_EQ(c10::kSparseBsc, c10::Layout::SparseBsc);
190+
EXPECT_EQ(c10::kMkldnn, c10::Layout::Mkldnn);
191+
EXPECT_EQ(c10::kJagged, c10::Layout::Jagged);
192+
}
193+
194+
// 测试 at 命名空间下的 layout 常量
195+
TEST_F(TensorTest, LayoutConstantsInAtNamespace) {
196+
EXPECT_EQ(at::kStrided, c10::Layout::Strided);
197+
EXPECT_EQ(at::kSparse, c10::Layout::Sparse);
198+
EXPECT_EQ(at::kSparseCsr, c10::Layout::SparseCsr);
199+
EXPECT_EQ(at::kSparseCsc, c10::Layout::SparseCsc);
200+
EXPECT_EQ(at::kSparseBsr, c10::Layout::SparseBsr);
201+
EXPECT_EQ(at::kSparseBsc, c10::Layout::SparseBsc);
202+
EXPECT_EQ(at::kMkldnn, c10::Layout::Mkldnn);
203+
EXPECT_EQ(at::kJagged, c10::Layout::Jagged);
204+
}
205+
206+
// 测试 torch 命名空间下的 layout 常量
207+
TEST_F(TensorTest, LayoutConstantsInTorchNamespace) {
208+
EXPECT_EQ(torch::kStrided, c10::Layout::Strided);
209+
EXPECT_EQ(torch::kSparse, c10::Layout::Sparse);
210+
EXPECT_EQ(torch::kSparseCsr, c10::Layout::SparseCsr);
211+
EXPECT_EQ(torch::kSparseCsc, c10::Layout::SparseCsc);
212+
EXPECT_EQ(torch::kSparseBsr, c10::Layout::SparseBsr);
213+
EXPECT_EQ(torch::kSparseBsc, c10::Layout::SparseBsc);
214+
EXPECT_EQ(torch::kMkldnn, c10::Layout::Mkldnn);
215+
EXPECT_EQ(torch::kJagged, c10::Layout::Jagged);
216+
}
217+
218+
// 测试 layout 枚举值
219+
TEST_F(TensorTest, LayoutEnumValues) {
220+
// 测试 Layout 枚举的底层值
221+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Strided), 0);
222+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Sparse), 1);
223+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseCsr), 2);
224+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Mkldnn), 3);
225+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseCsc), 4);
226+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseBsr), 5);
227+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseBsc), 6);
228+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Jagged), 7);
229+
EXPECT_EQ(static_cast<int8_t>(c10::Layout::NumOptions), 8);
230+
}
231+
232+
// 测试 layout 输出流操作符
233+
TEST_F(TensorTest, LayoutOutputStream) {
234+
std::ostringstream oss;
235+
236+
oss.str("");
237+
oss << c10::Layout::Strided;
238+
EXPECT_EQ(oss.str(), "Strided");
239+
240+
oss.str("");
241+
oss << c10::Layout::Sparse;
242+
EXPECT_EQ(oss.str(), "Sparse");
243+
244+
oss.str("");
245+
oss << c10::Layout::SparseCsr;
246+
EXPECT_EQ(oss.str(), "SparseCsr");
247+
248+
oss.str("");
249+
oss << c10::Layout::SparseCsc;
250+
EXPECT_EQ(oss.str(), "SparseCsc");
251+
252+
oss.str("");
253+
oss << c10::Layout::SparseBsr;
254+
EXPECT_EQ(oss.str(), "SparseBsr");
255+
256+
oss.str("");
257+
oss << c10::Layout::SparseBsc;
258+
EXPECT_EQ(oss.str(), "SparseBsc");
259+
260+
oss.str("");
261+
oss << c10::Layout::Mkldnn;
262+
EXPECT_EQ(oss.str(), "Mkldnn");
263+
264+
oss.str("");
265+
oss << c10::Layout::Jagged;
266+
EXPECT_EQ(oss.str(), "Jagged");
267+
}
268+
269+
// 测试使用 kStrided 常量与 tensor.layout() 比较
270+
TEST_F(TensorTest, LayoutWithConstant) {
271+
// 使用常量别名进行比较
272+
EXPECT_EQ(tensor.layout(), at::kStrided);
273+
EXPECT_EQ(tensor.layout(), torch::kStrided);
274+
EXPECT_EQ(tensor.layout(), c10::kStrided);
275+
276+
// 确保不是其他布局类型
277+
EXPECT_NE(tensor.layout(), at::kSparse);
278+
EXPECT_NE(tensor.layout(), at::kSparseCsr);
279+
EXPECT_NE(tensor.layout(), at::kMkldnn);
280+
}
281+
173282
} // namespace test
174283
} // namespace at

0 commit comments

Comments
 (0)