1111
1212#define MPIR_LSUM (a ,b ) ((a)+(b))
1313
14+ static void bfloat16_sum (void * invec , void * inoutvec , MPI_Aint len );
15+ static void f16_sum (void * invec , void * inoutvec , MPI_Aint len );
16+
1417void MPIR_SUM (void * invec , void * inoutvec , MPI_Aint * Len , MPI_Datatype * type )
1518{
1619 MPI_Aint i , len = * Len ;
@@ -35,6 +38,14 @@ void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
3538 break; \
3639 }
3740 MPIR_OP_TYPE_GROUP (COMPLEX )
41+ case MPIR_BFLOAT16 :
42+ bfloat16_sum (invec , inoutvec , len );
43+ break ;
44+ #ifndef MPIR_FLOAT16_CTYPE
45+ case MPIR_FLOAT16 :
46+ f16_sum (invec , inoutvec , len );
47+ break ;
48+ #endif
3849 default :
3950 MPIR_Assert (0 );
4051 break ;
@@ -442,3 +453,75 @@ void MPIR_REPLACE(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * ty
442453 fn_fail :
443454 goto fn_exit ;
444455}
456+
457+ /* -- internal static routines -- */
458+
459+ /* BFloat16 - software arithemetics
460+ * TODO: add hardware support, e.g. via AVX512 intrinsics
461+ */
462+ static float bfloat16_load (void * p )
463+ {
464+ uint32_t u = ((uint32_t ) (* (uint16_t * ) p ) << 16 );
465+ float v ;
466+ memcpy (& v , & u , sizeof (float ));
467+ return v ;
468+ }
469+
470+ static void bfloat16_store (void * p , float v )
471+ {
472+ uint32_t u ;
473+ memcpy (& u , & v , sizeof (float ));
474+ if (u & 0x8000 ) {
475+ /* round up */
476+ * (uint16_t * ) p = (u >> 16 ) + 1 ;
477+ } else {
478+ /* truncation */
479+ * (uint16_t * ) p = (u >> 16 );
480+ }
481+
482+ }
483+
484+ static void bfloat16_sum (void * invec , void * inoutvec , MPI_Aint len )
485+ {
486+ for (MPI_Aint i = 0 ; i < len * 2 ; i += 2 ) {
487+ float a = bfloat16_load ((char * ) inoutvec + i );
488+ float b = bfloat16_load ((char * ) invec + i );
489+ bfloat16_store ((char * ) inoutvec + i , a + b );
490+ }
491+ }
492+
493+ /* IEEE half-precision 16-bit float - software arithemetics
494+ */
495+ static float f16_load (void * p )
496+ {
497+ uint16_t a = * (uint16_t * ) p ;
498+ /* expand exponent from 5 bit to 8 bit, fraction from 10 bit to 23 bit */
499+ uint32_t u = ((uint32_t ) ((a & 0x8000 ) | ((((a & 0x3c00 ) >> 10 ) + 0x70 ) << 7 )) << 16 ) |
500+ ((uint32_t ) (a & 0x3ff ) << 13 );
501+ float v ;
502+ memcpy (& v , & u , sizeof (float ));
503+ return v ;
504+ }
505+
506+ static void f16_store (void * p , float v )
507+ {
508+ uint32_t u ;
509+ memcpy (& u , & v , sizeof (float ));
510+ /* shrink exponent from 8 bit to 5 bit, fraction from 23 bit to 10 bit */
511+ uint16_t a = ((u & 0x80000000 ) >> 16 ) | ((((u & 0x7f800000 ) >> 23 ) - 0x70 ) << 10 ) |
512+ ((u & 0x7fffff ) >> 16 );
513+ if (u & 0x1000 ) {
514+ /* round up */
515+ a += 1 ;
516+ }
517+ * (uint16_t * ) p = a ;
518+ }
519+
520+ static void f16_sum (void * invec , void * inoutvec , MPI_Aint len )
521+ {
522+ for (MPI_Aint i = 0 ; i < len * 2 ; i += 2 ) {
523+ float a = f16_load ((char * ) inoutvec + i );
524+ float b = f16_load ((char * ) invec + i );
525+ f16_store ((char * ) inoutvec + i , a + b );
526+ }
527+ }
0 commit comments