Skip to content

Commit e6590c6

Browse files
committed
FP16/BF16 types with defines
1 parent 72afb1a commit e6590c6

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

src_cpp/MPI1/MPI1_suite.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,16 +414,24 @@ MPI_Op get_op(MPI_Datatype type) {
414414
MPI_Datatype mpi_int = MPI_INT;
415415
MPI_Datatype mpi_float = MPI_FLOAT;
416416
MPI_Datatype mpi_double = MPI_DOUBLE;
417+
#ifdef MPIX_C_FLOAT16
417418
MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
419+
#endif
420+
#ifdef MPIX_C_BF16
418421
MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
422+
#endif
419423
size_t type_size = sizeof(MPI_Datatype);
420424

421425
if (!memcmp(&type, &mpi_char, type_size)) { MPI_Op_create(&(contig_sum<char>), 1, &op); }
422426
else if (!memcmp(&type, &mpi_int, type_size)) { MPI_Op_create(&(contig_sum<int>), 1, &op); }
423427
else if (!memcmp(&type, &mpi_float, type_size)) { MPI_Op_create(&(contig_sum<float>), 1, &op); }
424428
else if (!memcmp(&type, &mpi_double, type_size)) { MPI_Op_create(&(contig_sum<double>), 1, &op); }
429+
#ifdef MPIX_C_FLOAT16
425430
else if (!memcmp(&type, &mpi_float16, type_size)) { op = MPI_OP_NULL; fprintf(stdout, "\nWarning: contig_type doesn't supported\n"); }
431+
#endif
432+
#ifdef MPIX_C_BF16
426433
else if (!memcmp(&type, &mpi_bfloat16, type_size)) { op = MPI_OP_NULL; fprintf(stdout, "\nWarning: contig_type doesn't supported \n"); }
434+
#endif
427435

428436
return op;
429437
}
@@ -435,17 +443,25 @@ string type_to_name(MPI_Datatype type) {
435443
MPI_Datatype mpi_int = MPI_INT;
436444
MPI_Datatype mpi_float = MPI_FLOAT;
437445
MPI_Datatype mpi_double = MPI_DOUBLE;
446+
#ifdef MPIX_C_FLOAT16
438447
MPI_Datatype mpi_float16 = MPIX_C_FLOAT16;
448+
#endif
449+
#ifdef MPIX_C_BF16
439450
MPI_Datatype mpi_bfloat16 = MPIX_C_BF16;
451+
#endif
440452
size_t type_size = sizeof(MPI_Datatype);
441453

442454
if (!memcmp(&type, &mpi_byte, type_size)) { name = "MPI_BYTE"; }
443455
else if (!memcmp(&type, &mpi_char, type_size)) { name = "MPI_CHAR"; }
444456
else if (!memcmp(&type, &mpi_int, type_size)) { name = "MPI_INT"; }
445457
else if (!memcmp(&type, &mpi_float, type_size)) { name = "MPI_FLOAT"; }
446458
else if (!memcmp(&type, &mpi_double, type_size)) { name = "MPI_DOUBLE"; }
459+
#ifdef MPIX_C_FLOAT16
447460
else if (!memcmp(&type, &mpi_float16, type_size)) { name = "MPIX_C_FLOAT16"; }
461+
#endif
462+
#ifdef MPIX_C_BF16
448463
else if (!memcmp(&type, &mpi_bfloat16, type_size)) { name = "MPIX_C_BF16"; }
464+
#endif
449465

450466
return name;
451467
}
@@ -610,12 +626,16 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
610626
} else if (given_data_type == "double") {
611627
c_info.s_data_type = MPI_DOUBLE;
612628
c_info.r_data_type = MPI_DOUBLE;
629+
#ifdef MPIX_C_FLOAT16
613630
} else if (given_data_type == "float16") {
614631
c_info.s_data_type = MPIX_C_FLOAT16;
615632
c_info.r_data_type = MPIX_C_FLOAT16;
633+
#endif
634+
#ifdef MPIX_C_BF16
616635
} else if (given_data_type == "bfloat16") {
617636
c_info.s_data_type = MPIX_C_BF16;
618637
c_info.r_data_type = MPIX_C_BF16;
638+
#endif
619639
} else {
620640
output << "Invalid data_type " << given_data_type << endl;
621641
output << " Set data_type byte" << endl;
@@ -633,10 +653,14 @@ template <> bool BenchmarkSuite<BS_MPI1>::prepare(const args_parser &parser, con
633653
c_info.red_data_type = MPI_FLOAT;
634654
} else if (given_red_data_type == "double") {
635655
c_info.red_data_type = MPI_DOUBLE;
656+
#ifdef MPIX_C_FLOAT16
636657
} else if (given_red_data_type == "float16") {
637658
c_info.red_data_type = MPIX_C_FLOAT16;
659+
#endif
660+
#ifdef MPIX_C_BF16
638661
} else if (given_red_data_type == "bfloat16") {
639662
c_info.red_data_type = MPIX_C_BF16;
663+
#endif
640664
} else {
641665
output << "Invalid red_data_type " << given_red_data_type << endl;
642666
output << " Set red_data_type float" << endl;

0 commit comments

Comments
 (0)