Skip to content

Commit 0f9e6c3

Browse files
committed
mpi: add MPIX_BFLOAT16
This serves as an example of how to add a new builtin MPI datatype: 1. Define the constant in mpi.h.in. 2. (Optional) Define the internal datatype in mpir_datatype.h if there isn't one already; 2a. add its alignment in MPIR_Datatype_builtintype_alignment; 2b. add its mapping in MPII_Typerep_get_yaksa_type. 3. Define the mapping in configure.ac. 4. (Optional) Define a case for each supported reduction op.
1 parent f8ae768 commit 0f9e6c3

File tree

7 files changed

+50
-2
lines changed

7 files changed

+50
-2
lines changed

configure.ac

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3576,6 +3576,7 @@ AC_DEFINE_UNQUOTED([MPIR_LOGICAL2_INTERNAL], [MPIR_FORTRAN_LOGICAL16], [Interna
35763576
AC_DEFINE_UNQUOTED([MPIR_LOGICAL4_INTERNAL], [MPIR_FORTRAN_LOGICAL32], [Internal type for MPI_LOGICAL4])
35773577
AC_DEFINE_UNQUOTED([MPIR_LOGICAL8_INTERNAL], [MPIR_FORTRAN_LOGICAL64], [Internal type for MPI_LOGICAL8])
35783578
AC_DEFINE_UNQUOTED([MPIR_LOGICAL16_INTERNAL], [MPIR_FORTRAN_LOGICAL128],[Internal type for MPI_LOGICAL16])
3579+
AC_DEFINE_UNQUOTED([MPIR_BFLOAT16_INTERNAL], [MPIR_BFLOAT16],[Internal type for MPIX_BFLOAT16])
35793580
AC_MSG_RESULT([done])
35803581

35813582
# ------------------------------------------------------------------------

src/include/mpi.h.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ typedef int MPI_Datatype;
259259
#define MPI_LOGICAL4 ((MPI_Datatype)0x4c000449)
260260
#define MPI_LOGICAL8 ((MPI_Datatype)0x4c00084a)
261261
#define MPI_LOGICAL16 ((MPI_Datatype)0x4c00104b)
262+
/* other */
263+
#define MPIX_BFLOAT16 ((MPI_Datatype)0x4c00024c)
262264

263265
/* Communicators */
264266
typedef int MPI_Comm;

src/include/mpir_datatype.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
#define MPIR_COMPLEX32 ((MPI_Datatype)0x4c840800)
7373
#define MPIR_COMPLEX64 ((MPI_Datatype)0x4c841000)
7474
#define MPIR_COMPLEX128 ((MPI_Datatype)0x4c842000)
75+
#define MPIR_BFLOAT16 ((MPI_Datatype)0x4c850200) /* bfloat16, use MPIR_TYPE_ALT_FLOAT */
7576
#define MPIR_ALT_FLOAT96 ((MPI_Datatype)0x4c850c00) /* long double (80-bit extended precision) on i386 */
7677
#define MPIR_ALT_FLOAT128 ((MPI_Datatype)0x4c851000) /* long double (80-bit extended precision) on x86-64 */
7778
#define MPIR_ALT_COMPLEX96 ((MPI_Datatype)0x4c861800) /* long double complex on i386 */

src/include/mpir_objects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ const char *MPIR_Handle_get_kind_str(int kind);
217217
#define MPIR_GROUP_PREALLOC 8
218218
#endif
219219

220-
#define MPIR_DATATYPE_N_BUILTIN 76 /* 0x4c - must be in sync with mpi.h.in */
220+
#define MPIR_DATATYPE_N_BUILTIN 77 /* 0x4d - must be in sync with mpi.h.in */
221221
#ifdef MPID_DATATYPE_PREALLOC
222222
#define MPIR_DATATYPE_PREALLOC MPID_DATATYPE_PREALLOC
223223
#else

src/mpi/coll/op/op_fns.c

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#define MPIR_LSUM(a,b) ((a)+(b))
1313

14+
static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len);
15+
1416
void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
1517
{
1618
MPI_Aint i, len = *Len;
@@ -35,6 +37,9 @@ void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
3537
break; \
3638
}
3739
MPIR_OP_TYPE_GROUP(COMPLEX)
40+
case MPIR_BFLOAT16:
41+
bfloat16_sum(invec, inoutvec, len);
42+
break;
3843
default:
3944
MPIR_Assert(0);
4045
break;
@@ -442,3 +447,39 @@ void MPIR_REPLACE(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * ty
442447
fn_fail:
443448
goto fn_exit;
444449
}
450+
451+
/* -- internal static routines -- */
452+
453+
/* BFloat16 - software arithmetic
454+
* TODO: add hardware support, e.g. via AVX512 intrinsics
455+
*/
456+
/* Widen the 16-bit bfloat16 value stored at p to a float.
 *
 * The 16 stored bits become the upper half of a 32-bit word whose lower half
 * is zero; that word is then reinterpreted as a float.
 *
 * p - address of a 2-byte bfloat16 value in native byte order. It is read
 *     with memcpy, so it does not need to be 2-byte aligned and no
 *     pointer-cast type punning is involved.
 *
 * Returns the widened float value.
 */
static float bfloat16_load(const void *p)
{
    uint16_t h;
    memcpy(&h, p, sizeof(h));

    uint32_t u = (uint32_t) h << 16;

    float v;
    memcpy(&v, &u, sizeof(v));
    return v;
}
463+
464+
/* Narrow the float v to bfloat16 and store the 16-bit result at p.
 *
 * p - destination for the 2-byte bfloat16 value, written in native byte
 *     order with memcpy, so it does not need to be 2-byte aligned.
 * v - value to convert.
 *
 * Conversion rules:
 *  - NaN stays NaN: the sign and the upper mantissa bits are kept and the
 *    top mantissa bit is forced on, so the result can never collapse to
 *    infinity or zero when the discarded low bits held the whole payload.
 *  - Everything else is rounded to nearest, ties to even, on the 16
 *    discarded bits. A carry out of the mantissa propagates into the
 *    exponent, so values just below a power of two round up correctly and
 *    the largest finite values round to infinity.
 */
static void bfloat16_store(void *p, float v)
{
    uint32_t u;
    memcpy(&u, &v, sizeof(u));

    uint16_t h;
    if ((u & 0x7fffffffu) > 0x7f800000u) {
        /* NaN: never apply rounding, it could carry into the sign bit or
         * wrap around to zero */
        h = (uint16_t) ((u >> 16) | 0x0040u);
    } else {
        /* Adding 0x7fff plus the lowest kept bit rounds to nearest and
         * resolves exact ties toward the even result. For non-NaN inputs
         * the addition cannot overflow 32 bits. */
        uint32_t bias = 0x7fffu + ((u >> 16) & 1u);
        h = (uint16_t) ((u + bias) >> 16);
    }

    memcpy(p, &h, sizeof(h));
}
477+
478+
/* Element-wise bfloat16 sum: inoutvec[k] = inoutvec[k] + invec[k].
 *
 * invec    - input vector of len bfloat16 elements (2 bytes each)
 * inoutvec - accumulator vector of len bfloat16 elements, updated in place
 * len      - number of elements
 *
 * Each pair is widened to float with bfloat16_load, added in float, and
 * narrowed back with bfloat16_store.
 */
static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len)
{
    char *src = (char *) invec;
    char *dst = (char *) inoutvec;
    MPI_Aint nbytes = len * 2;
    MPI_Aint off = 0;

    while (off < nbytes) {
        float acc = bfloat16_load(dst + off);
        float add = bfloat16_load(src + off);
        bfloat16_store(dst + off, acc + add);
        off += 2;
    }
}

src/mpi/datatype/typerep/src/typerep_yaksa_init.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ yaksa_type_t MPII_Typerep_get_yaksa_type(MPI_Datatype type)
5252

5353
case MPIR_FIXED16:
5454
case MPIR_FLOAT16:
55+
case MPIR_BFLOAT16:
5556
case MPIR_COMPLEX8:
5657
yaksa_type = TYPEREP_YAKSA_TYPE__FIXED2;
5758
break;

src/mpi/datatype/typeutil.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ struct MPIR_Datatype_builtin_entry MPIR_Internal_types[] = {
103103
type_name_entry(LOGICAL4, LOGICAL), /* 0x49 */
104104
type_name_entry(LOGICAL8, LOGICAL), /* 0x4a */
105105
type_name_entry(LOGICAL16, LOGICAL), /* 0x4b */
106+
type_name_x(BFLOAT16, FLOATING_POINT), /* 0x4c */
106107
/* *INDENT-ON* */
107108
};
108109

@@ -169,6 +170,7 @@ int MPIR_Datatype_builtintype_alignment(MPI_Datatype type)
169170
case MPIR_INT16:
170171
case MPIR_UINT16:
171172
case MPIR_FLOAT16:
173+
case MPIR_BFLOAT16:
172174
return ALIGNOF_INT16_T;
173175
case MPIR_FIXED32:
174176
case MPIR_INT32:
@@ -190,7 +192,7 @@ int MPIR_Datatype_builtintype_alignment(MPI_Datatype type)
190192
case MPIR_ALT_COMPLEX128:
191193
return ALIGNOF_LONG_DOUBLE;
192194
default:
193-
/* handle error cases? */
195+
MPIR_Assert(0);
194196
return 1;
195197
}
196198
}

0 commit comments

Comments
 (0)