Skip to content

Commit 934340d

Browse files
committed
add ScalarType related tests
1 parent 6e7d15d commit 934340d

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

test/TensorTest.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,65 @@ TEST_F(TensorTest, Transpose) {
170170
EXPECT_EQ(transposed.sizes()[2], 2);
171171
}
172172

173+
// 测试 is_complex
174+
TEST_F(TensorTest, IsComplex) {
175+
// Float tensor should not be complex
176+
EXPECT_FALSE(tensor.is_complex());
177+
178+
// Test with actual complex tensor
179+
at::Tensor complex_tensor =
180+
at::ones({2, 3}, at::TensorOptions().dtype(at::kComplexFloat));
181+
EXPECT_TRUE(complex_tensor.is_complex());
182+
183+
at::Tensor complex_double_tensor =
184+
at::ones({2, 3}, at::TensorOptions().dtype(at::kComplexDouble));
185+
EXPECT_TRUE(complex_double_tensor.is_complex());
186+
}
187+
188+
// 测试 is_floating_point
189+
TEST_F(TensorTest, IsFloatingPoint) {
190+
// Float tensor should be floating point
191+
EXPECT_TRUE(tensor.is_floating_point());
192+
193+
// Test with double tensor
194+
at::Tensor double_tensor =
195+
at::ones({2, 3}, at::TensorOptions().dtype(at::kDouble));
196+
EXPECT_TRUE(double_tensor.is_floating_point());
197+
198+
// Test with integer tensor
199+
at::Tensor int_tensor = at::ones({2, 3}, at::TensorOptions().dtype(at::kInt));
200+
EXPECT_FALSE(int_tensor.is_floating_point());
201+
202+
// Test with long tensor
203+
at::Tensor long_tensor =
204+
at::ones({2, 3}, at::TensorOptions().dtype(at::kLong));
205+
EXPECT_FALSE(long_tensor.is_floating_point());
206+
}
207+
208+
// 测试 is_signed
209+
TEST_F(TensorTest, IsSigned) {
210+
// Float tensor should be signed
211+
EXPECT_TRUE(tensor.is_signed());
212+
213+
// Test with int tensor (signed)
214+
at::Tensor int_tensor = at::ones({2, 3}, at::TensorOptions().dtype(at::kInt));
215+
EXPECT_TRUE(int_tensor.is_signed());
216+
217+
// Test with long tensor (signed)
218+
at::Tensor long_tensor =
219+
at::ones({2, 3}, at::TensorOptions().dtype(at::kLong));
220+
EXPECT_TRUE(long_tensor.is_signed());
221+
222+
// Test with byte tensor (unsigned)
223+
at::Tensor byte_tensor =
224+
at::ones({2, 3}, at::TensorOptions().dtype(at::kByte));
225+
EXPECT_FALSE(byte_tensor.is_signed());
226+
227+
// Test with bool tensor (unsigned)
228+
at::Tensor bool_tensor =
229+
at::ones({2, 3}, at::TensorOptions().dtype(at::kBool));
230+
EXPECT_FALSE(bool_tensor.is_signed());
231+
}
232+
173233
} // namespace test
174234
} // namespace at

0 commit comments

Comments
 (0)