@@ -26,10 +26,10 @@ def round_ndarray(
2626 is_negative = np .signbit (v ) & fi .is_signed
2727 absv = np .where (is_negative , - v , v )
2828
29- nonzerofinite_mask = ~ (np .isnan (v ) | np .isinf (v ) | (v == 0 ))
29+ finite_nonzero = ~ (np .isnan (v ) | np .isinf (v ) | (v == 0 ))
3030
31- # Place 1.0 where nonzerofinite_mask is False
32- absv_masked = np .where (nonzerofinite_mask , absv , 1.0 )
31+ # Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
32+ absv_masked = np .where (finite_nonzero , absv , 1.0 )
3333
3434 expval = np .floor (np .log2 (absv_masked )).astype (int )
3535
@@ -69,23 +69,23 @@ def round_ndarray(
6969 expval += round_up & (isignificand == 1 )
7070 isignificand = np .where (round_up , 1 , isignificand )
7171
72- result = np .where (nonzerofinite_mask , isignificand * (2.0 ** expval ), absv )
72+ result = np .where (finite_nonzero , isignificand * (2.0 ** expval ), absv )
7373
7474 amax = np .where (is_negative , - fi .min , fi .max )
7575
7676 if sat :
7777 result = np .where (result > amax , amax , result )
7878 else :
7979 if rnd == RoundMode .TowardNegative :
80- put_amax_at = (result > amax ) & nonzerofinite_mask & ~ is_negative
80+ put_amax_at = (result > amax ) & ~ is_negative
8181 elif rnd == RoundMode .TowardPositive :
82- put_amax_at = (result > amax ) & nonzerofinite_mask & is_negative
82+ put_amax_at = (result > amax ) & is_negative
8383 elif rnd == RoundMode .TowardZero :
84- put_amax_at = ( result > amax ) & nonzerofinite_mask
84+ put_amax_at = result > amax
8585 else :
8686 put_amax_at = np .zeros_like (result , dtype = bool )
8787
88- result = np .where (put_amax_at , amax , result )
88+ result = np .where (finite_nonzero & put_amax_at , amax , result )
8989
9090 # Now anything larger than amax goes to infinity or NaN
9191 if fi .has_infs :
@@ -119,7 +119,6 @@ def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
119119 vpos = np .where (sign , - v , v )
120120
121121 nan_mask = np .isnan (v )
122- inf_mask = np .isinf (v )
123122
124123 code = np .zeros_like (v , dtype = np .uint64 )
125124
@@ -148,15 +147,12 @@ def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
148147 finite_sign = sign [finite_mask ]
149148
150149 sig , exp = np .frexp (finite_vpos )
151- expl = exp .astype (int ) - 1
152- tsig = sig * 2
153150
154- biased_exp = expl + fi .expBias
151+ biased_exp = exp . astype ( np . int64 ) + ( fi .expBias - 1 )
155152 subnormal_mask = (biased_exp < 1 ) & fi .has_subnormals
156153
157- tsig [ subnormal_mask ] *= 2.0 ** ( biased_exp [ subnormal_mask ] - 1 )
154+ tsig = np . where ( subnormal_mask , sig * 2.0 ** biased_exp , sig * 2 - 1.0 )
158155 biased_exp [subnormal_mask ] = 0
159- tsig [~ subnormal_mask ] -= 1.0
160156
161157 isig = np .floor (tsig * 2 ** t ).astype (int )
162158
0 commit comments