Skip to content

Commit 4c66094

Browse files
committed
add ScalarType related tests
1 parent b7d76d2 commit 4c66094

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

test/ScalarTypeTest.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
#include <torch/all.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
20+
class ScalarTypeTest : public ::testing::Test {
21+
protected:
22+
void SetUp() override {
23+
std::vector<int64_t> shape = {2, 3, 4};
24+
tensor = at::ones(shape, at::kFloat);
25+
}
26+
27+
at::Tensor tensor;
28+
};
29+
30+
// 测试 is_complex
31+
TEST_F(ScalarTypeTest, IsComplex) {
32+
auto file_name = g_custom_param.get();
33+
FileManerger file(file_name);
34+
file.createFile();
35+
36+
// Float tensor should not be complex
37+
file << std::to_string(tensor.is_complex()) << " ";
38+
39+
// Test with actual complex tensor
40+
at::Tensor complex_tensor =
41+
at::ones({2, 3}, at::TensorOptions().dtype(at::kComplexFloat));
42+
file << std::to_string(complex_tensor.is_complex()) << " ";
43+
44+
at::Tensor complex_double_tensor =
45+
at::ones({2, 3}, at::TensorOptions().dtype(at::kComplexDouble));
46+
file << std::to_string(complex_double_tensor.is_complex()) << " ";
47+
file.saveFile();
48+
}
49+
50+
// 测试 is_floating_point
51+
TEST_F(ScalarTypeTest, IsFloatingPoint) {
52+
auto file_name = g_custom_param.get();
53+
FileManerger file(file_name);
54+
file.createFile();
55+
56+
// Float tensor should be floating point
57+
file << std::to_string(tensor.is_floating_point()) << " ";
58+
59+
// Test with double tensor
60+
at::Tensor double_tensor =
61+
at::ones({2, 3}, at::TensorOptions().dtype(at::kDouble));
62+
file << std::to_string(double_tensor.is_floating_point()) << " ";
63+
64+
// Test with integer tensor
65+
at::Tensor int_tensor = at::ones({2, 3}, at::TensorOptions().dtype(at::kInt));
66+
file << std::to_string(int_tensor.is_floating_point()) << " ";
67+
68+
// Test with long tensor
69+
at::Tensor long_tensor =
70+
at::ones({2, 3}, at::TensorOptions().dtype(at::kLong));
71+
file << std::to_string(long_tensor.is_floating_point()) << " ";
72+
file.saveFile();
73+
}
74+
75+
// 测试 is_signed
76+
TEST_F(ScalarTypeTest, IsSigned) {
77+
auto file_name = g_custom_param.get();
78+
FileManerger file(file_name);
79+
file.createFile();
80+
81+
// Float tensor should be signed
82+
file << std::to_string(tensor.is_signed()) << " ";
83+
84+
// Test with int tensor (signed)
85+
at::Tensor int_tensor = at::ones({2, 3}, at::TensorOptions().dtype(at::kInt));
86+
file << std::to_string(int_tensor.is_signed()) << " ";
87+
88+
// Test with long tensor (signed)
89+
at::Tensor long_tensor =
90+
at::ones({2, 3}, at::TensorOptions().dtype(at::kLong));
91+
file << std::to_string(long_tensor.is_signed()) << " ";
92+
93+
// Test with byte tensor (unsigned)
94+
at::Tensor byte_tensor =
95+
at::ones({2, 3}, at::TensorOptions().dtype(at::kByte));
96+
file << std::to_string(byte_tensor.is_signed()) << " ";
97+
98+
// Test with bool tensor (unsigned)
99+
at::Tensor bool_tensor =
100+
at::ones({2, 3}, at::TensorOptions().dtype(at::kBool));
101+
file << std::to_string(bool_tensor.is_signed()) << " ";
102+
file.saveFile();
103+
}
104+
105+
} // namespace test
106+
} // namespace at

0 commit comments

Comments
 (0)