@@ -36,7 +36,9 @@ typedef std::map<
3636 std::type_index,
3737 std::variant<
3838 std::vector<float >,
39- std::vector<double >>>
39+ std::vector<double >,
40+ std::vector<exec_aten::Half>,
41+ std::vector<exec_aten::BFloat16>>>
4042 FloatingTypeToDataMap;
4143
4244typedef std::map<
@@ -309,9 +311,9 @@ TEST_F(OpToTest, AllDtypesSupported) {
309311 ScalarType::OUTPUT_DTYPE>(test_cases);
310312
311313#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
312- ET_FORALL_REAL_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
314+ ET_FORALL_REALHBF16_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
313315
314- ET_FORALL_REAL_TYPES (TEST_ENTRY);
316+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
315317
316318#undef TEST_ENTRY
317319#undef TEST_KERNEL
@@ -323,14 +325,14 @@ TEST_F(OpToTest, BoolTests) {
323325#define TEST_TO_BOOL (INPUT_CTYPE, INPUT_DTYPE ) \
324326 test_runner_to_bool<INPUT_CTYPE, ScalarType::INPUT_DTYPE>( \
325327 test_case_to_bool, result_to_bool);
326- ET_FORALL_REAL_TYPES (TEST_TO_BOOL);
328+ ET_FORALL_REALHBF16_TYPES (TEST_TO_BOOL);
327329
328330 std::vector<uint8_t > test_case_from_bool = {true , true , false };
329331 std::vector<double > result_from_bool = {1.0 , 1.0 , 0 };
330332#define TEST_FROM_BOOL (OUTPUT_CTYPE, OUTPUT_DTYPE ) \
331333 test_runner_from_bool<OUTPUT_CTYPE, ScalarType::OUTPUT_DTYPE>( \
332334 test_case_from_bool, result_from_bool);
333- ET_FORALL_REAL_TYPES (TEST_FROM_BOOL);
335+ ET_FORALL_REALHBF16_TYPES (TEST_FROM_BOOL);
334336}
335337
336338TEST_F (OpToTest, NanInfSupported) {
@@ -349,9 +351,9 @@ TEST_F(OpToTest, NanInfSupported) {
349351 ScalarType::OUTPUT_DTYPE>(test_cases);
350352
351353#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
352- ET_FORALL_FLOAT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
354+ ET_FORALL_FLOATHBF16_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
353355
354- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
356+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
355357
356358#undef TEST_ENTRY
357359#undef TEST_KERNEL
@@ -381,6 +383,13 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
381383 -0.30919688936285893988 };
382384 // clang-format on
383385
386+ std::vector<exec_aten::Half> half_data;
387+ std::vector<exec_aten::BFloat16> bf16_data;
388+ for (auto d : double_data) {
389+ half_data.emplace_back (d);
390+ bf16_data.emplace_back (d);
391+ }
392+
384393 std::vector<int64_t > int64_data = {
385394 -1 , -4 , 2 , -2 , 3 , 3 , -3 , -4 , 3 , 3 , 0 , 2 , 0 , -1 , 0 };
386395 std::vector<int32_t > int32_data = {
@@ -394,6 +403,8 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
394403 FloatingTypeToDataMap floating_point_data;
395404 floating_point_data[typeid (float )] = float_data;
396405 floating_point_data[typeid (double )] = double_data;
406+ floating_point_data[typeid (exec_aten::Half)] = half_data;
407+ floating_point_data[typeid (exec_aten::BFloat16)] = bf16_data;
397408
398409 // Gathering all int data together for better traversial
399410 IntTypeToDataMap int_data;
@@ -412,7 +423,7 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
412423#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
413424 ET_FORALL_INT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
414425
415- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
426+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
416427}
417428
418429TEST_F (OpToTest, MismatchedSizesDie) {
0 commit comments