Skip to content

Commit 4cc92dd

Browse files
committed
op: add software SUM support for MPIX_C_FLOAT16
Provide half-precision float sum operation by casting to C float.
1 parent b66d778 commit 4cc92dd

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

src/mpi/coll/op/op_fns.c

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#define MPIR_LSUM(a,b) ((a)+(b))
1313

1414
static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len);
15+
static void f16_sum(void *invec, void *inoutvec, MPI_Aint len);
1516

1617
void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
1718
{
@@ -40,6 +41,11 @@ void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type)
4041
case MPIR_BFLOAT16:
4142
bfloat16_sum(invec, inoutvec, len);
4243
break;
44+
#ifndef MPIR_FLOAT16_CTYPE
45+
case MPIR_FLOAT16:
46+
f16_sum(invec, inoutvec, len);
47+
break;
48+
#endif
4349
default:
4450
MPIR_Assert(0);
4551
break;
@@ -483,3 +489,39 @@ static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len)
483489
bfloat16_store((char *) inoutvec + i, a + b);
484490
}
485491
}
492+
493+
/* IEEE half-precision 16-bit float - software arithmetic
494+
*/
495+
/*
 * Decode one IEEE 754 binary16 (half-precision) value into a C float.
 *
 * p points at the 2 bytes holding the half value in host byte order; the
 * bytes are copied out with memcpy, so p need not be aligned for uint16_t.
 *
 * binary16 layout: 1 sign bit (0x8000), 5 exponent bits (0x7c00, bias 15),
 * 10 fraction bits (0x03ff).  Every binary16 value is exactly representable
 * as a binary32 float, so this conversion is lossless.  Zeros, subnormals,
 * infinities and NaNs are all handled explicitly.
 */
static float f16_load(void *p)
{
    uint16_t a;
    memcpy(&a, p, sizeof(a));

    uint32_t sign = (uint32_t) (a & 0x8000) << 16;
    uint32_t exp = (uint32_t) (a >> 10) & 0x1f;
    uint32_t frac = (uint32_t) a & 0x3ff;
    uint32_t u;

    if (exp == 0x1f) {
        /* infinity (frac == 0) or NaN (frac != 0): all-ones exponent,
         * NaN payload moved to the top of the 23-bit fraction */
        u = sign | 0x7f800000 | (frac << 13);
    } else if (exp != 0) {
        /* normal number: rebias exponent from 15 to 127 (+0x70) and widen
         * the fraction from 10 to 23 bits */
        u = sign | ((exp + 0x70) << 23) | (frac << 13);
    } else if (frac != 0) {
        /* subnormal half (frac * 2^-24): it is a normal float, so shift the
         * fraction up until the implicit leading 1 reaches bit 10 and lower
         * the exponent accordingly.  0x71 is the biased float exponent of
         * 2^-14, the scale of the half subnormal range. */
        exp = 0x71;
        while (!(frac & 0x400)) {
            frac <<= 1;
            exp--;
        }
        u = sign | (exp << 23) | ((frac & 0x3ff) << 13);
    } else {
        /* signed zero */
        u = sign;
    }

    float v;
    memcpy(&v, &u, sizeof(float));
    return v;
}
505+
506+
/*
 * Encode a C float as one IEEE 754 binary16 (half-precision) value and
 * write its 2 bytes to p in host byte order.  The bytes are written with
 * memcpy, so p need not be aligned for uint16_t.
 *
 * The float's 8-bit exponent (bias 127) is rebiased to 5 bits (bias 15) and
 * its 23-bit fraction is narrowed to 10 bits using round-to-nearest,
 * ties-to-even.  Results too large for binary16 become infinity, results
 * below the smallest half subnormal become (signed) zero, and NaN stays NaN.
 */
static void f16_store(void *p, float v)
{
    uint32_t u;
    memcpy(&u, &v, sizeof(float));

    uint32_t sign = (u >> 16) & 0x8000;
    uint32_t exp = (u >> 23) & 0xff;
    uint32_t frac = u & 0x7fffff;
    uint32_t h;

    if (exp == 0xff) {
        /* infinity or NaN */
        h = 0x7c00;
        if (frac != 0) {
            /* force the quiet bit so a NaN never collapses to infinity when
             * its payload lives only in the discarded low bits */
            h |= 0x200 | (frac >> 13);
        }
    } else if (exp > 0x70 + 30) {
        /* magnitude >= 2^16: overflows binary16 */
        h = 0x7c00;
    } else if (exp > 0x70) {
        /* normal half: drop the low 13 fraction bits, then round.  A carry
         * out of the fraction correctly bumps the exponent, and may turn
         * the largest finite value into infinity. */
        uint32_t rem = frac & 0x1fff;
        h = ((exp - 0x70) << 10) | (frac >> 13);
        if (rem > 0x1000 || (rem == 0x1000 && (h & 1))) {
            h += 1;
        }
    } else if (exp >= 0x70 - 10) {
        /* subnormal half: result is the 24-bit significand (with its
         * implicit leading 1) scaled to units of 2^-24, i.e. shifted right
         * by 126 - exp (14..24 bits), then rounded.  Rounding up from
         * 0x3ff yields 0x400, the smallest normal half. */
        uint32_t m = frac | 0x800000;
        uint32_t shift = 0x7e - exp;
        uint32_t half = (uint32_t) 1 << (shift - 1);
        uint32_t rem = m & ((half << 1) - 1);
        h = m >> shift;
        if (rem > half || (rem == half && (h & 1))) {
            h += 1;
        }
    } else {
        /* too small even for a half subnormal (includes float zero and
         * float subnormals): flush to signed zero */
        h = 0;
    }

    uint16_t a = (uint16_t) (sign | h);
    memcpy(p, &a, sizeof(a));
}
519+
520+
/*
 * Element-wise half-precision sum: inoutvec[k] = inoutvec[k] + invec[k]
 * for len elements of 2 bytes each.  Each pair is widened to float with
 * f16_load, added in float, and written back through f16_store.
 */
static void f16_sum(void *invec, void *inoutvec, MPI_Aint len)
{
    char *src = (char *) invec;
    char *dst = (char *) inoutvec;
    const MPI_Aint nbytes = len * 2;
    MPI_Aint off = 0;

    while (off < nbytes) {
        /* load the accumulator first, then the incoming operand */
        float acc = f16_load(dst + off);
        float addend = f16_load(src + off);
        f16_store(dst + off, acc + addend);
        off += 2;
    }
}

0 commit comments

Comments
 (0)