|
12 | 12 | #define MPIR_LSUM(a,b) ((a)+(b)) |
13 | 13 |
|
14 | 14 | static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len); |
| 15 | +static void f16_sum(void *invec, void *inoutvec, MPI_Aint len); |
15 | 16 |
|
16 | 17 | void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type) |
17 | 18 | { |
@@ -40,6 +41,11 @@ void MPIR_SUM(void *invec, void *inoutvec, MPI_Aint * Len, MPI_Datatype * type) |
40 | 41 | case MPIR_BFLOAT16: |
41 | 42 | bfloat16_sum(invec, inoutvec, len); |
42 | 43 | break; |
| 44 | +#ifndef MPIR_FLOAT16_CTYPE |
| 45 | + case MPIR_FLOAT16: |
| 46 | + f16_sum(invec, inoutvec, len); |
| 47 | + break; |
| 48 | +#endif |
43 | 49 | default: |
44 | 50 | MPIR_Assert(0); |
45 | 51 | break; |
@@ -483,3 +489,39 @@ static void bfloat16_sum(void *invec, void *inoutvec, MPI_Aint len) |
483 | 489 | bfloat16_store((char *) inoutvec + i, a + b); |
484 | 490 | } |
485 | 491 | } |
| 492 | + |
| 493 | +/* IEEE half-precision 16-bit float - software arithmetic |
| 494 | + */ |
/* Decode one IEEE 754 binary16 (half-precision) value stored at p into a float.
 *
 * Layout of binary16: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
 * Layout of binary32: 1 sign bit, 8 exponent bits (bias 127), 23 fraction bits.
 * Every binary16 value is exactly representable as a binary32 value, so the
 * conversion is lossless.
 *
 * p may be unaligned; the 16-bit pattern is read with memcpy.
 */
static float f16_load(void *p)
{
    uint16_t h;
    memcpy(&h, p, sizeof(h));

    uint32_t sign = (uint32_t) (h & 0x8000) << 16;
    uint32_t exp = (uint32_t) (h >> 10) & 0x1f;
    uint32_t frac = (uint32_t) h & 0x3ff;
    uint32_t u;

    if (exp == 0x1f) {
        /* infinity (frac == 0) or NaN (frac != 0): all-ones exponent in both formats */
        u = sign | 0x7f800000 | (frac << 13);
    } else if (exp != 0) {
        /* normal number: re-bias exponent (127 - 15 = 0x70), widen fraction */
        u = sign | ((exp + 0x70) << 23) | (frac << 13);
    } else if (frac != 0) {
        /* subnormal half: value is frac * 2^-24; normalize it so the leading
         * one becomes the implicit bit of a normal float */
        exp = 0x71;     /* biased float exponent of 2^-14 */
        while (!(frac & 0x400)) {
            frac <<= 1;
            exp--;
        }
        u = sign | (exp << 23) | ((frac & 0x3ff) << 13);
    } else {
        /* signed zero */
        u = sign;
    }

    float v;
    memcpy(&v, &u, sizeof(v));
    return v;
}
| 505 | + |
/* Encode float v as an IEEE 754 binary16 (half-precision) value and write the
 * 16-bit pattern to p.
 *
 * Layout of binary32: 1 sign bit, 8 exponent bits (bias 127), 23 fraction bits.
 * Layout of binary16: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
 *
 * - Rounding is round-to-nearest, ties-to-even (the IEEE default).  A rounding
 *   carry is allowed to propagate from the fraction into the exponent, which
 *   correctly yields the next power of two, the smallest normal, or infinity.
 * - Magnitudes too large for binary16 become infinity.
 * - Magnitudes below the smallest normal are encoded as subnormals; those that
 *   round to nothing become signed zero.
 * - NaN stays NaN (the quiet bit is forced so the payload cannot truncate to
 *   the infinity pattern).
 *
 * p may be unaligned; the result is written with memcpy.
 */
static void f16_store(void *p, float v)
{
    uint32_t u;
    memcpy(&u, &v, sizeof(u));

    uint32_t sign = (u >> 16) & 0x8000;
    uint32_t exp = (u >> 23) & 0xff;
    uint32_t frac = u & 0x7fffff;
    uint32_t h;

    if (exp == 0xff) {
        /* infinity or NaN */
        h = 0x7c00;
        if (frac != 0) {
            h |= 0x200 | (frac >> 13);
        }
    } else if (exp > 0x70 + 0x1e) {
        /* exponent above the largest finite binary16 exponent: overflow */
        h = 0x7c00;
    } else if (exp >= 0x66) {
        /* representable (possibly after rounding) as a normal or subnormal */
        uint32_t shift;
        if (exp > 0x70) {
            /* normal: re-bias exponent (127 - 15 = 0x70), drop 13 fraction bits */
            shift = 13;
            h = ((exp - 0x70) << 10) | (frac >> shift);
        } else {
            /* subnormal: make the implicit leading one explicit, then shift the
             * 24-bit significand so that one unit equals 2^-24 */
            frac |= 0x800000;
            shift = 0x7e - exp;         /* 14 .. 24 */
            h = frac >> shift;
        }
        /* round to nearest, ties to even, based on the discarded bits */
        uint32_t half = (uint32_t) 1 << (shift - 1);
        uint32_t rem = frac & ((half << 1) - 1);
        if (rem > half || (rem == half && (h & 1))) {
            h++;
        }
    } else {
        /* too small even for the smallest subnormal: signed zero */
        h = 0;
    }

    uint16_t a = (uint16_t) (sign | h);
    memcpy(p, &a, sizeof(a));
}
| 519 | + |
| 520 | +static void f16_sum(void *invec, void *inoutvec, MPI_Aint len) |
| 521 | +{ |
| 522 | + for (MPI_Aint i = 0; i < len * 2; i += 2) { |
| 523 | + float a = f16_load((char *) inoutvec + i); |
| 524 | + float b = f16_load((char *) invec + i); |
| 525 | + f16_store((char *) inoutvec + i, a + b); |
| 526 | + } |
| 527 | +} |
0 commit comments