|
| 1 | +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from .types import FloatClass, FloatValue, FormatInfo |
| 6 | + |
| 7 | + |
| 8 | +def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np=np) -> np.ndarray: |
| 9 | + r""" |
| 10 | + Vectorized version of :function:`decode_float` |
| 11 | +
|
| 12 | + Args: |
| 13 | + fi (FormatInfo): Floating point format descriptor. |
| 14 | + i (array of int): Integer code points, in the range :math:`0 \le i < 2^{k}`, |
| 15 | + where :math:`k` = ``fi.k`` |
| 16 | +
|
| 17 | + Returns: |
| 18 | + Decoded float values |
| 19 | +
|
| 20 | + Raises: |
| 21 | + ValueError: |
| 22 | + If any :paramref:`i` is outside the range of valid code points in :paramref:`fi`. |
| 23 | + """ |
| 24 | + assert np.issubdtype(codes.dtype, np.integer) |
| 25 | + |
| 26 | + k = fi.k |
| 27 | + p = fi.precision |
| 28 | + t = p - 1 # Trailing significand field width |
| 29 | + num_signbits = 1 if fi.is_signed else 0 |
| 30 | + w = k - t - num_signbits # Exponent field width |
| 31 | + |
| 32 | + if np.any(codes < 0) or np.any(codes >= 2**k): |
| 33 | + raise ValueError(f"Code point not in range [0, 2**{k})") |
| 34 | + |
| 35 | + if fi.is_signed: |
| 36 | + signmask = 1 << (k - 1) |
| 37 | + sign = np.where(codes & signmask, -1.0, 1.0) |
| 38 | + else: |
| 39 | + signmask = None |
| 40 | + sign = 1.0 |
| 41 | + |
| 42 | + exp = (codes >> t) & ((1 << w) - 1) |
| 43 | + significand = codes & ((1 << t) - 1) |
| 44 | + if fi.is_twos_complement: |
| 45 | + significand = np.where(sign < 0, (1 << t) - significand, significand) |
| 46 | + |
| 47 | + expBias = fi.expBias |
| 48 | + |
| 49 | + iszero = (exp == 0) & (significand == 0) if fi.has_zero else False |
| 50 | + issubnormal = (exp == 0) & (significand != 0) if fi.has_subnormals else False |
| 51 | + isnormal = ~iszero & ~issubnormal |
| 52 | + expval = np.where(~isnormal, 1 - expBias, exp - expBias) |
| 53 | + fsignificand = np.where(~isnormal, significand * 2**-t, 1.0 + significand * 2**-t) |
| 54 | + |
| 55 | + # Normal/Subnormal/Zero case, other values will be overwritten |
| 56 | + fval = np.where(iszero, 0.0, sign * fsignificand * 2.0**expval) |
| 57 | + |
| 58 | + # All-bits-special exponent (ABSE) |
| 59 | + if w > 0: |
| 60 | + abse = exp == 2**w - 1 |
| 61 | + min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans |
| 62 | + fval = np.where(abse & (significand >= min_i_with_nan), np.nan, fval) |
| 63 | + if fi.has_infs: |
| 64 | + fval = np.where( |
| 65 | + abse & (significand == min_i_with_nan - 1), np.inf * sign, fval |
| 66 | + ) |
| 67 | + |
| 68 | + # Negative zero |
| 69 | + if fi.has_nz: |
| 70 | + fval = np.where(iszero & (sign < 0), -0.0, fval) |
| 71 | + else: |
| 72 | + # Negative zero slot is nan |
| 73 | + fval = np.where(codes == fi.code_of_negzero, np.nan, fval) |
| 74 | + |
| 75 | + return fval |
0 commit comments