@@ -36,7 +36,9 @@ typedef std::map<
3636 std::type_index,
3737 std::variant<
3838 std::vector<float >,
39- std::vector<double >>>
39+ std::vector<double >,
40+ std::vector<exec_aten::Half>,
41+ std::vector<exec_aten::BFloat16>>>
4042 FloatingTypeToDataMap;
4143
4244typedef std::map<
@@ -381,9 +383,9 @@ TEST_F(OpToDimOrderCopyTest, NanInfSupported) {
381383 ScalarType::OUTPUT_DTYPE>(test_cases);
382384
383385#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
384- ET_FORALL_FLOAT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
386+ ET_FORALL_FLOATHBF16_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
385387
386- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
388+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
387389
388390#undef TEST_ENTRY
389391#undef TEST_KERNEL
@@ -413,6 +415,13 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
413415 -0.30919688936285893988 };
414416 // clang-format on
415417
418+ std::vector<exec_aten::Half> half_data;
419+ std::vector<exec_aten::BFloat16> bf16_data;
420+ for (auto d : double_data) {
421+ half_data.emplace_back (d);
422+ bf16_data.emplace_back (d);
423+ }
424+
416425 std::vector<int64_t > int64_data = {
417426 -1 , -4 , 2 , -2 , 3 , 3 , -3 , -4 , 3 , 3 , 0 , 2 , 0 , -1 , 0 };
418427 std::vector<int32_t > int32_data = {
@@ -426,6 +435,8 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
426435 FloatingTypeToDataMap floating_point_data;
427436 floating_point_data[typeid (float )] = float_data;
428437 floating_point_data[typeid (double )] = double_data;
438+ floating_point_data[typeid (exec_aten::Half)] = half_data;
439+ floating_point_data[typeid (exec_aten::BFloat16)] = bf16_data;
429440
430441 // Gathering all int data together for better traversial
431442 IntTypeToDataMap int_data;
@@ -444,7 +455,7 @@ TEST_F(OpToDimOrderCopyTest, HardcodeFloatConvertInt) {
444455#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
445456 ET_FORALL_INT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
446457
447- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
458+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
448459}
449460
450461TEST_F (OpToDimOrderCopyTest, MismatchedSizesDie) {
0 commit comments