@@ -312,7 +312,7 @@ template <> bool BenchmarkSuite<BS_MPI1>::declare_args(args_parser &parser, std:
312312 set_description (
313313 " The argument after -data_type is a one from possible strings,\n "
314314 " Specifying that type will be used:\n "
315- " byte, char, int, float, double\n "
315+ " byte, char, int, float, double, float16, bfloat16 \n "
316316 " \n "
317317 " Example:\n "
318318 " -data_type char\n "
@@ -323,7 +323,7 @@ template <> bool BenchmarkSuite<BS_MPI1>::declare_args(args_parser &parser, std:
323323 set_description (
324324 " The argument after -red_data_type is a one from possible strings,\n "
325325 " Specifying that type will be used:\n "
326- " char, int, float, double\n "
326+ " char, int, float, double, float16, bfloat16 \n "
327327 " \n "
328328 " Example:\n "
329329 " -red_data_type int\n "
@@ -414,12 +414,24 @@ MPI_Op get_op(MPI_Datatype type) {
414414 MPI_Datatype mpi_int = MPI_INT;
415415 MPI_Datatype mpi_float = MPI_FLOAT;
416416 MPI_Datatype mpi_double = MPI_DOUBLE;
417+ #ifdef MPIX_C_FLOAT16
418+ MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
419+ #endif
420+ #ifdef MPIX_C_BF16
421+ MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
422+ #endif
417423 size_t type_size = sizeof (MPI_Datatype);
418424
419425 if (!memcmp (&type, &mpi_char, type_size)) { MPI_Op_create (&(contig_sum<char >), 1 , &op); }
420426 else if (!memcmp (&type, &mpi_int, type_size)) { MPI_Op_create (&(contig_sum<int >), 1 , &op); }
421427 else if (!memcmp (&type, &mpi_float, type_size)) { MPI_Op_create (&(contig_sum<float >), 1 , &op); }
422428 else if (!memcmp (&type, &mpi_double, type_size)) { MPI_Op_create (&(contig_sum<double >), 1 , &op); }
429+ #ifdef MPIX_C_FLOAT16
430+ else if (!memcmp (&type, &mpi_float16, type_size)) { op = MPI_OP_NULL; fprintf (stdout, " \n Warning: contig_type doesn't supported\n " ); }
431+ #endif
432+ #ifdef MPIX_C_BF16
433+ else if (!memcmp (&type, &mpi_bfloat16, type_size)) { op = MPI_OP_NULL; fprintf (stdout, " \n Warning: contig_type doesn't supported \n " ); }
434+ #endif
423435
424436 return op;
425437}
@@ -431,13 +443,25 @@ string type_to_name(MPI_Datatype type) {
431443 MPI_Datatype mpi_int = MPI_INT;
432444 MPI_Datatype mpi_float = MPI_FLOAT;
433445 MPI_Datatype mpi_double = MPI_DOUBLE;
446+ #ifdef MPIX_C_FLOAT16
447+ MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
448+ #endif
449+ #ifdef MPIX_C_BF16
450+ MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
451+ #endif
434452 size_t type_size = sizeof (MPI_Datatype);
435453
436454 if (!memcmp (&type, &mpi_byte, type_size)) { name = " MPI_BYTE" ; }
437455 else if (!memcmp (&type, &mpi_char, type_size)) { name = " MPI_CHAR" ; }
438456 else if (!memcmp (&type, &mpi_int, type_size)) { name = " MPI_INT" ; }
439457 else if (!memcmp (&type, &mpi_float, type_size)) { name = " MPI_FLOAT" ; }
440458 else if (!memcmp (&type, &mpi_double, type_size)) { name = " MPI_DOUBLE" ; }
459+ #ifdef MPIX_C_FLOAT16
460+ else if (!memcmp (&type, &mpi_float16, type_size)) { name = " MPIX_C_FLOAT16" ; }
461+ #endif
462+ #ifdef MPIX_C_BF16
463+ else if (!memcmp (&type, &mpi_bfloat16, type_size)) { name = " MPIX_C_BF16" ; }
464+ #endif
441465
442466 return name;
443467}
@@ -602,6 +626,16 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
602626 } else if (given_data_type == " double" ) {
603627 c_info.s_data_type = MPI_DOUBLE;
604628 c_info.r_data_type = MPI_DOUBLE;
629+ #ifdef MPIX_C_FLOAT16
630+ } else if (given_data_type == " float16" ) {
631+ c_info.s_data_type = MPIX_C_FLOAT16;
632+ c_info.r_data_type = MPIX_C_FLOAT16;
633+ #endif
634+ #ifdef MPIX_C_BF16
635+ } else if (given_data_type == " bfloat16" ) {
636+ c_info.s_data_type = MPIX_C_BF16;
637+ c_info.r_data_type = MPIX_C_BF16;
638+ #endif
605639 } else {
606640 output << " Invalid data_type " << given_data_type << endl;
607641 output << " Set data_type byte" << endl;
@@ -619,6 +653,14 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
619653 c_info.red_data_type = MPI_FLOAT;
620654 } else if (given_red_data_type == " double" ) {
621655 c_info.red_data_type = MPI_DOUBLE;
656+ #ifdef MPIX_C_FLOAT16
657+ } else if (given_red_data_type == " float16" ) {
658+ c_info.red_data_type = MPIX_C_FLOAT16;
659+ #endif
660+ #ifdef MPIX_C_BF16
661+ } else if (given_red_data_type == " bfloat16" ) {
662+ c_info.red_data_type = MPIX_C_BF16;
663+ #endif
622664 } else {
623665 output << " Invalid red_data_type " << given_red_data_type << endl;
624666 output << " Set red_data_type float" << endl;
0 commit comments