Skip to content

Commit 9d4abb3

Browse files
authored
Merge pull request pmodels#7345 from hzhou/2503_logical_types
datatype: support MPI_LOGICAL{1,2,4,8,16} and bfloat16 Approved-by: Lisandro Dalcin Approved-by: Ken Raffenetti
2 parents 3635504 + 57a0702 commit 9d4abb3

File tree

15 files changed

+415
-335
lines changed

15 files changed

+415
-335
lines changed

CHANGES

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ of MPI_INT. Commonly used types include MPI_BYTE, MPI_CHAR, MPI_AINT, use
2222
MPIR_BYTE_INTERNAL, MPIR_CHAR_INTERNAL, MPIR_AINT_INTERNAL instead. There is no
2323
impact to users.
2424

25+
# Added MPI_LOGICAL1, MPI_LOGICAL2, MPI_LOGICAL4, MPI_LOGICAL8, and MPI_LOGICAL16.
26+
27+
# Added MPIX_BFLOAT16, and added software reduction support for MPIX_BFLOAT16
28+
and MPIX_C_FLOAT16.
29+
2530
===============================================================================
2631
Changes in 4.3
2732
===============================================================================

configure.ac

Lines changed: 179 additions & 145 deletions
Large diffs are not rendered by default.

maint/gen_abi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ def gen_mpi_abi_internal_h(out):
6060
elif T == "MPI_Datatype":
6161
idx = int(val, 0) & G.datatype_mask
6262
G.abi_datatypes[idx] = name
63-
if re.match(r'MPI_LOGICAL\d+', name):
64-
G.abi_datatypes[idx] = "MPI_DATATYPE_NULL"
6563
elif T == "MPI_Op":
6664
idx = int(val, 0) & G.op_mask
6765
G.abi_ops[idx] = name

src/include/mpi.h.in

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ typedef int MPI_Datatype;
253253
#define MPI_COUNT ((MPI_Datatype)0x4c000845)
254254
/* other extension types */
255255
#define MPIX_C_FLOAT16 ((MPI_Datatype)0x4c000246)
256+
/* Fortran fixed-width logicals */
257+
#define MPI_LOGICAL1 ((MPI_Datatype)0x4c000147)
258+
#define MPI_LOGICAL2 ((MPI_Datatype)0x4c000248)
259+
#define MPI_LOGICAL4 ((MPI_Datatype)0x4c000449)
260+
#define MPI_LOGICAL8 ((MPI_Datatype)0x4c00084a)
261+
#define MPI_LOGICAL16 ((MPI_Datatype)0x4c00104b)
262+
/* other */
263+
#define MPIX_BFLOAT16 ((MPI_Datatype)0x4c00024c)
256264

257265
/* Communicators */
258266
typedef int MPI_Comm;

src/include/mpir_datatype.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
#define MPIR_COMPLEX32 ((MPI_Datatype)0x4c840800)
7373
#define MPIR_COMPLEX64 ((MPI_Datatype)0x4c841000)
7474
#define MPIR_COMPLEX128 ((MPI_Datatype)0x4c842000)
75+
#define MPIR_BFLOAT16 ((MPI_Datatype)0x4c850200) /* bfloat16, use MPIR_TYPE_ALT_FLOAT */
7576
#define MPIR_ALT_FLOAT96 ((MPI_Datatype)0x4c850c00) /* long double (80-bit extended precision) on i386 */
7677
#define MPIR_ALT_FLOAT128 ((MPI_Datatype)0x4c851000) /* long double (80-bit extended precision) on x86-64 */
7778
#define MPIR_ALT_COMPLEX96 ((MPI_Datatype)0x4c861800) /* long double complex on i386 */

src/include/mpir_objects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ const char *MPIR_Handle_get_kind_str(int kind);
217217
#define MPIR_GROUP_PREALLOC 8
218218
#endif
219219

220-
#define MPIR_DATATYPE_N_BUILTIN 71
220+
#define MPIR_DATATYPE_N_BUILTIN 77 /* 0x4d - must be in sync with mpi.h.in */
221221
#ifdef MPID_DATATYPE_PREALLOC
222222
#define MPIR_DATATYPE_PREALLOC MPID_DATATYPE_PREALLOC
223223
#else

src/mpi/coll/op/op_fns.c

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
#define MPIR_LSUM(a,b) ((a)+(b))
1313

14+
static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len);
15+
static void f16_sum(void *invec, void *inoutvec, MPI_Aint len);
16+
1417
void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
1518
{
1619
MPI_Aint i, len = *Len;
@@ -35,6 +38,14 @@ void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
3538
break; \
3639
}
3740
MPIR_OP_TYPE_GROUP(COMPLEX)
41+
case MPIR_BFLOAT16:
42+
bfloat16_sum(invec, inoutvec, len);
43+
break;
44+
#ifndef MPIR_FLOAT16_CTYPE
45+
case MPIR_FLOAT16:
46+
f16_sum(invec, inoutvec, len);
47+
break;
48+
#endif
3849
default:
3950
MPIR_Assert(0);
4051
break;
@@ -442,3 +453,75 @@ void MPIR_REPLACE(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * ty
442453
fn_fail:
443454
goto fn_exit;
444455
}
456+
457+
/* -- internal static routines -- */
458+
459+
/* BFloat16 - software arithemetics
460+
* TODO: add hardware support, e.g. via AVX512 intrinsics
461+
*/
462+
static float bfloat16_load(void *p)
463+
{
464+
uint32_t u = ((uint32_t) (*(uint16_t *) p) << 16);
465+
float v;
466+
memcpy(&v, &u, sizeof(float));
467+
return v;
468+
}
469+
470+
static void bfloat16_store(void *p, float v)
471+
{
472+
uint32_t u;
473+
memcpy(&u, &v, sizeof(float));
474+
if (u & 0x8000) {
475+
/* round up */
476+
*(uint16_t *) p = (u >> 16) + 1;
477+
} else {
478+
/* truncation */
479+
*(uint16_t *) p = (u >> 16);
480+
}
481+
482+
}
483+
484+
static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len)
485+
{
486+
for (MPI_Aint i = 0; i < len * 2; i += 2) {
487+
float a = bfloat16_load((char *) inoutvec + i);
488+
float b = bfloat16_load((char *) invec + i);
489+
bfloat16_store((char *) inoutvec + i, a + b);
490+
}
491+
}
492+
493+
/* IEEE half-precision 16-bit float - software arithemetics
494+
*/
495+
static float f16_load(void *p)
496+
{
497+
uint16_t a = *(uint16_t *) p;
498+
/* expand exponent from 5 bit to 8 bit, fraction from 10 bit to 23 bit */
499+
uint32_t u = ((uint32_t) ((a & 0x8000) | ((((a & 0x3c00) >> 10) + 0x70) << 7)) << 16) |
500+
((uint32_t) (a & 0x3ff) << 13);
501+
float v;
502+
memcpy(&v, &u, sizeof(float));
503+
return v;
504+
}
505+
506+
static void f16_store(void *p, float v)
507+
{
508+
uint32_t u;
509+
memcpy(&u, &v, sizeof(float));
510+
/* shrink exponent from 8 bit to 5 bit, fraction from 23 bit to 10 bit */
511+
uint16_t a = ((u & 0x80000000) >> 16) | ((((u & 0x7f800000) >> 23) - 0x70) << 10) |
512+
((u & 0x7fffff) >> 16);
513+
if (u & 0x1000) {
514+
/* round up */
515+
a += 1;
516+
}
517+
*(uint16_t *) p = a;
518+
}
519+
520+
static void f16_sum(void *invec, void *inoutvec, MPI_Aint len)
521+
{
522+
for (MPI_Aint i = 0; i < len * 2; i += 2) {
523+
float a = f16_load((char *) inoutvec + i);
524+
float b = f16_load((char *) invec + i);
525+
f16_store((char *) inoutvec + i, a + b);
526+
}
527+
}

src/mpi/datatype/typerep/src/typerep_yaksa_init.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ yaksa_type_t MPII_Typerep_get_yaksa_type(MPI_Datatype type)
3030

3131
switch (MPIR_DATATYPE_GET_RAW_INTERNAL(type)) {
3232
case MPIR_INT8:
33+
case MPIR_FORTRAN_LOGICAL8:
3334
yaksa_type = YAKSA_TYPE__INT8_T;
3435
break;
3536

@@ -43,6 +44,7 @@ yaksa_type_t MPII_Typerep_get_yaksa_type(MPI_Datatype type)
4344
break;
4445

4546
case MPIR_INT16:
47+
case MPIR_FORTRAN_LOGICAL16:
4648
yaksa_type = YAKSA_TYPE__INT16_T;
4749
break;
4850

@@ -52,11 +54,13 @@ yaksa_type_t MPII_Typerep_get_yaksa_type(MPI_Datatype type)
5254

5355
case MPIR_FIXED16:
5456
case MPIR_FLOAT16:
57+
case MPIR_BFLOAT16:
5558
case MPIR_COMPLEX8:
5659
yaksa_type = TYPEREP_YAKSA_TYPE__FIXED2;
5760
break;
5861

5962
case MPIR_INT32:
63+
case MPIR_FORTRAN_LOGICAL32:
6064
yaksa_type = YAKSA_TYPE__INT32_T;
6165
break;
6266

@@ -70,6 +74,7 @@ yaksa_type_t MPII_Typerep_get_yaksa_type(MPI_Datatype type)
7074
break;
7175

7276
case MPIR_INT64:
77+
case MPIR_FORTRAN_LOGICAL64:
7378
yaksa_type = YAKSA_TYPE__INT64_T;
7479
break;
7580

@@ -111,6 +116,7 @@ yaksa_type_t MPII_Typerep_get_yaksa_type(MPI_Datatype type)
111116
case MPIR_INT128:
112117
case MPIR_UINT128:
113118
case MPIR_FLOAT128:
119+
case MPIR_FORTRAN_LOGICAL128:
114120
yaksa_type = TYPEREP_YAKSA_TYPE__FIXED16;
115121
break;
116122

src/mpi/datatype/typerep/src/typerep_yaksa_pack_external.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ typedef struct {
7575
} while (0)
7676

7777
/* long double */
78-
#ifdef HAVE_FLOAT128
79-
#define EXTERNAL_LONG_DOUBLE_TYPE __float128
78+
#ifdef MPIR_FLOAT128_CTYPE
79+
#define EXTERNAL_LONG_DOUBLE_TYPE MPIR_FLOAT128_CTYPE
8080
#else
8181
#define EXTERNAL_LONG_DOUBLE_TYPE long double
8282
#endif

src/mpi/datatype/typeutil.c

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ struct MPIR_Datatype_builtin_entry MPIR_Internal_types[] = {
9898
type_name_entry(OFFSET, MULTI), /* 0x44 */
9999
type_name_entry(COUNT, MULTI), /* 0x45 */
100100
type_name_x(C_FLOAT16, FLOATING_POINT), /* 0x46 */
101+
type_name_entry(LOGICAL1, LOGICAL), /* 0x47 */
102+
type_name_entry(LOGICAL2, LOGICAL), /* 0x48 */
103+
type_name_entry(LOGICAL4, LOGICAL), /* 0x49 */
104+
type_name_entry(LOGICAL8, LOGICAL), /* 0x4a */
105+
type_name_entry(LOGICAL16, LOGICAL), /* 0x4b */
106+
type_name_x(BFLOAT16, FLOATING_POINT), /* 0x4c */
101107
/* *INDENT-ON* */
102108
};
103109

@@ -158,34 +164,61 @@ int MPIR_Datatype_builtintype_alignment(MPI_Datatype type)
158164
case MPIR_FIXED8:
159165
case MPIR_INT8:
160166
case MPIR_UINT8:
161-
case MPIR_FLOAT8:
162-
return ALIGNOF_INT8_T;
167+
case MPIR_COMPLEX8:
168+
case MPIR_FORTRAN_LOGICAL8:
169+
return MPIR_INT8_ALIGN;
163170
case MPIR_FIXED16:
164171
case MPIR_INT16:
165172
case MPIR_UINT16:
166-
case MPIR_FLOAT16:
167-
return ALIGNOF_INT16_T;
173+
case MPIR_FORTRAN_LOGICAL16:
174+
return MPIR_INT16_ALIGN;
168175
case MPIR_FIXED32:
169176
case MPIR_INT32:
170177
case MPIR_UINT32:
171-
return ALIGNOF_INT32_T;
178+
case MPIR_FORTRAN_LOGICAL32:
179+
return MPIR_INT32_ALIGN;
172180
case MPIR_FIXED64:
173181
case MPIR_INT64:
174182
case MPIR_UINT64:
175-
return ALIGNOF_INT64_T;
183+
case MPIR_FORTRAN_LOGICAL64:
184+
return MPIR_INT64_ALIGN;
185+
#ifdef MPIR_INT128_ALIGN
186+
case MPIR_FIXED128:
187+
case MPIR_INT128:
188+
case MPIR_UINT128:
189+
case MPIR_FORTRAN_LOGICAL128:
190+
return MPIR_INT128_ALIGN;
191+
#endif
192+
#ifdef MPIR_FLOAT16_ALIGN
193+
case MPIR_FLOAT16:
194+
case MPIR_COMPLEX16:
195+
case MPIR_BFLOAT16:
196+
return MPIR_FLOAT16_ALIGN;
197+
#endif
176198
case MPIR_FLOAT32:
177199
case MPIR_COMPLEX32:
178-
return ALIGNOF_FLOAT;
200+
return MPIR_FLOAT32_ALIGN;
179201
case MPIR_FLOAT64:
180202
case MPIR_COMPLEX64:
181-
return ALIGNOF_DOUBLE;
203+
return MPIR_FLOAT64_ALIGN;
204+
#ifdef MPIR_FLOAT128_ALIGN
205+
case MPIR_FLOAT128:
206+
case MPIR_COMPLEX128:
207+
return MPIR_FLOAT128_ALIGN;
208+
#endif
209+
#ifdef MPIR_ALT_FLOAT96_ALIGN
182210
case MPIR_ALT_FLOAT96:
183-
case MPIR_ALT_FLOAT128:
184211
case MPIR_ALT_COMPLEX96:
212+
return MPIR_ALT_FLOAT96_ALIGN;
213+
#endif
214+
#ifdef MPIR_ALT_FLOAT128_ALIGN
185215
case MPIR_ALT_COMPLEX128:
186-
return ALIGNOF_LONG_DOUBLE;
216+
case MPIR_ALT_FLOAT128:
217+
return MPIR_ALT_FLOAT128_ALIGN;
218+
#endif
187219
default:
188-
/* handle error cases? */
220+
/* FIXME: throw error */
221+
MPIR_Assert(0);
189222
return 1;
190223
}
191224
}

0 commit comments

Comments
 (0)