Skip to content

Commit 326a0ea

Browse files
laramielcopybara-github
authored andcommitted
Add some additional data type tests
PiperOrigin-RevId: 840120632 Change-Id: I8594a5502e6c7790dc49e57d06312bfda4b08656
1 parent 3e0e4b6 commit 326a0ea

File tree

1 file changed

+98
-33
lines changed

1 file changed

+98
-33
lines changed

tensorstore/data_type_test.cc

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <stdint.h>
2121

2222
#include <cmath>
23+
#include <limits>
2324
#include <memory>
2425
#include <string>
2526
#include <type_traits>
@@ -279,8 +280,18 @@ TEST(StaticElementRepresentationDeathTest, UnsignedInt) {
279280
"StaticCast is not valid");
280281
}
281282

283+
using CompareFloatTestTypes =
284+
::testing::Types<float8_e3m4_t, float8_e4m3fn_t, float8_e4m3fnuz_t,
285+
float8_e4m3b11fnuz_t, float8_e5m2_t, float8_e5m2fnuz_t,
286+
float4_e2m1fn_t, bfloat16_t, float16_t, float32_t,
287+
float64_t>;
288+
282289
template <typename T>
283-
void TestCompareIdenticalFloat() {
290+
class CompareFloatTest : public ::testing::Test {};
291+
TYPED_TEST_SUITE(CompareFloatTest, CompareFloatTestTypes);
292+
293+
TYPED_TEST(CompareFloatTest, Identical) {
294+
using T = TypeParam;
284295
const auto compare = [](auto a, auto b) {
285296
return tensorstore::internal_data_type::CompareIdentical(static_cast<T>(a),
286297
static_cast<T>(b));
@@ -290,17 +301,30 @@ void TestCompareIdenticalFloat() {
290301
EXPECT_FALSE(compare(1.0, 2.0));
291302
EXPECT_TRUE(compare(+0.0, +0.0));
292303
EXPECT_TRUE(compare(-0.0, -0.0));
293-
EXPECT_FALSE(compare(+0.0, -0.0));
294-
EXPECT_TRUE(compare(NAN, NAN));
295-
EXPECT_TRUE(compare(INFINITY, INFINITY));
296-
EXPECT_TRUE(compare(-INFINITY, -INFINITY));
297-
EXPECT_FALSE(compare(-INFINITY, INFINITY));
298-
EXPECT_FALSE(compare(NAN, 1));
299-
EXPECT_FALSE(compare(1, NAN));
304+
if constexpr (std::numeric_limits<T>::is_iec559) {
305+
EXPECT_FALSE(compare(+0.0, -0.0));
306+
}
307+
if constexpr (std::numeric_limits<T>::has_quiet_NaN) {
308+
EXPECT_TRUE(compare(NAN, NAN));
309+
EXPECT_FALSE(compare(NAN, 1));
310+
EXPECT_FALSE(compare(1, NAN));
311+
}
312+
if constexpr (std::numeric_limits<T>::has_infinity) {
313+
EXPECT_TRUE(compare(INFINITY, INFINITY));
314+
EXPECT_TRUE(compare(-INFINITY, -INFINITY));
315+
EXPECT_FALSE(compare(-INFINITY, INFINITY));
316+
}
300317
}
301318

319+
using CompareIdenticalComplexTestTypes =
320+
::testing::Types<tensorstore::dtypes::complex64_t,
321+
tensorstore::dtypes::complex128_t>;
302322
template <typename T>
303-
void TestCompareIdenticalComplex() {
323+
class CompareIdenticalComplexTest : public ::testing::Test {};
324+
TYPED_TEST_SUITE(CompareIdenticalComplexTest, CompareIdenticalComplexTestTypes);
325+
326+
TYPED_TEST(CompareIdenticalComplexTest, Basic) {
327+
using T = TypeParam;
304328
using value_type = typename T::value_type;
305329
const auto compare = [](auto ra, auto ia, auto rb, auto ib) {
306330
return tensorstore::internal_data_type::CompareIdentical(
@@ -318,30 +342,6 @@ void TestCompareIdenticalComplex() {
318342
EXPECT_FALSE(compare(1.0, NAN, 1.0, 2.0));
319343
}
320344

321-
TEST(CompareIdenticalTest, Float32) {
322-
TestCompareIdenticalFloat<tensorstore::dtypes::float32_t>();
323-
}
324-
325-
TEST(CompareIdenticalTest, Float64) {
326-
TestCompareIdenticalFloat<tensorstore::dtypes::float64_t>();
327-
}
328-
329-
TEST(CompareIdenticalTest, Bfloat16) {
330-
TestCompareIdenticalFloat<tensorstore::dtypes::bfloat16_t>();
331-
}
332-
333-
TEST(CompareIdenticalTest, Float16) {
334-
TestCompareIdenticalFloat<tensorstore::dtypes::float16_t>();
335-
}
336-
337-
TEST(CompareIdenticalTest, Complex64) {
338-
TestCompareIdenticalComplex<tensorstore::dtypes::complex64_t>();
339-
}
340-
341-
TEST(CompareIdenticalTest, Complex128) {
342-
TestCompareIdenticalComplex<tensorstore::dtypes::complex128_t>();
343-
}
344-
345345
TEST(ElementOperationsTest, FloatCompareIdentical) {
346346
DataType r = dtype_v<unsigned int>;
347347

@@ -455,6 +455,42 @@ TEST(AllocateAndConsructSharedTest, ValueInitialization) {
455455
EXPECT_EQ(0, ptr.get()[1]);
456456
}
457457

458+
using AllocateAndConstructTestTypes =
459+
::testing::Types<int, float, complex128_t, std::string,
460+
std::shared_ptr<int>, int2_t, int4_t, bfloat16_t,
461+
float16_t, float8_e3m4_t, float4_e2m1fn_t, json_t>;
462+
463+
template <typename T>
464+
class AllocateAndConstructTest : public ::testing::Test {
465+
public:
466+
T GetDefaultValue() {
467+
if constexpr (std::is_same_v<T, std::shared_ptr<int>>) {
468+
return std::make_shared<int>();
469+
} else if constexpr (std::is_same_v<T, std::string>) {
470+
return std::string(1000, 'a');
471+
} else {
472+
return T{};
473+
}
474+
}
475+
};
476+
TYPED_TEST_SUITE(AllocateAndConstructTest, AllocateAndConstructTestTypes);
477+
478+
TYPED_TEST(AllocateAndConstructTest, DefaultInitialization) {
479+
auto x = this->GetDefaultValue();
480+
TypeParam* ptr = static_cast<TypeParam*>(tensorstore::AllocateAndConstruct(
481+
1, tensorstore::default_init, tensorstore::dtype_v<TypeParam>));
482+
ptr[0] = std::move(x);
483+
tensorstore::DestroyAndFree(1, tensorstore::dtype_v<TypeParam>, ptr);
484+
}
485+
486+
TYPED_TEST(AllocateAndConstructTest, ValueInitialization) {
487+
TypeParam* ptr = static_cast<TypeParam*>(tensorstore::AllocateAndConstruct(
488+
2, tensorstore::value_init, tensorstore::dtype_v<TypeParam>));
489+
EXPECT_EQ(TypeParam(), ptr[0]);
490+
EXPECT_EQ(TypeParam(), ptr[1]);
491+
tensorstore::DestroyAndFree(2, tensorstore::dtype_v<TypeParam>, ptr);
492+
}
493+
458494
// Thread sanitizer considers `operator new` allocation failure an error, and
459495
// prevents this death test from working.
460496
#if !defined(THREAD_SANITIZER)
@@ -486,6 +522,17 @@ TEST(DataTypeTest, Name) {
486522
EXPECT_EQ("uint32", DataType(dtype_v<uint32_t>).name());
487523
EXPECT_EQ("int64", DataType(dtype_v<int64_t>).name());
488524
EXPECT_EQ("uint64", DataType(dtype_v<uint64_t>).name());
525+
526+
EXPECT_EQ("float8_e3m4", DataType(dtype_v<float8_e3m4_t>).name());
527+
EXPECT_EQ("float8_e4m3fn", DataType(dtype_v<float8_e4m3fn_t>).name());
528+
EXPECT_EQ("float8_e4m3fnuz", DataType(dtype_v<float8_e4m3fnuz_t>).name());
529+
EXPECT_EQ("float8_e4m3b11fnuz",
530+
DataType(dtype_v<float8_e4m3b11fnuz_t>).name());
531+
EXPECT_EQ("float8_e5m2", DataType(dtype_v<float8_e5m2_t>).name());
532+
EXPECT_EQ("float8_e5m2fnuz", DataType(dtype_v<float8_e5m2fnuz_t>).name());
533+
EXPECT_EQ("float4_e2m1fn", DataType(dtype_v<float4_e2m1fn_t>).name());
534+
535+
EXPECT_EQ("bfloat16", DataType(dtype_v<bfloat16_t>).name());
489536
EXPECT_EQ("float16", DataType(dtype_v<float16_t>).name());
490537
EXPECT_EQ("float32", DataType(dtype_v<float32_t>).name());
491538
EXPECT_EQ("float64", DataType(dtype_v<float64_t>).name());
@@ -514,6 +561,15 @@ TEST(DataTypeTest, GetDataType) {
514561
EXPECT_EQ(dtype_v<uint32_t>, GetDataType("uint32"));
515562
EXPECT_EQ(dtype_v<int64_t>, GetDataType("int64"));
516563
EXPECT_EQ(dtype_v<uint64_t>, GetDataType("uint64"));
564+
565+
EXPECT_EQ(dtype_v<float8_e3m4_t>, GetDataType("float8_e3m4"));
566+
EXPECT_EQ(dtype_v<float8_e4m3fn_t>, GetDataType("float8_e4m3fn"));
567+
EXPECT_EQ(dtype_v<float8_e4m3fnuz_t>, GetDataType("float8_e4m3fnuz"));
568+
EXPECT_EQ(dtype_v<float8_e4m3b11fnuz_t>, GetDataType("float8_e4m3b11fnuz"));
569+
EXPECT_EQ(dtype_v<float8_e5m2_t>, GetDataType("float8_e5m2"));
570+
EXPECT_EQ(dtype_v<float8_e5m2fnuz_t>, GetDataType("float8_e5m2fnuz"));
571+
EXPECT_EQ(dtype_v<float4_e2m1fn_t>, GetDataType("float4_e2m1fn"));
572+
517573
EXPECT_EQ(dtype_v<bfloat16_t>, GetDataType("bfloat16"));
518574
EXPECT_EQ(dtype_v<float16_t>, GetDataType("float16"));
519575
EXPECT_EQ(dtype_v<float32_t>, GetDataType("float32"));
@@ -574,6 +630,15 @@ TEST(SerializationTest, Invalid) {
574630
static_assert(IsTrivial<bool>);
575631
static_assert(IsTrivial<int2_t>);
576632
static_assert(IsTrivial<int4_t>);
633+
634+
static_assert(IsTrivial<float8_e3m4_t>);
635+
static_assert(IsTrivial<float8_e4m3fn_t>);
636+
static_assert(IsTrivial<float8_e4m3fnuz_t>);
637+
static_assert(IsTrivial<float8_e4m3b11fnuz_t>);
638+
static_assert(IsTrivial<float8_e5m2_t>);
639+
static_assert(IsTrivial<float8_e5m2fnuz_t>);
640+
static_assert(IsTrivial<float4_e2m1fn_t>);
641+
577642
static_assert(IsTrivial<bfloat16_t>);
578643
static_assert(IsTrivial<float16_t>);
579644
static_assert(IsTrivial<float32_t>);

0 commit comments

Comments
 (0)