Skip to content

Commit 3121325

Browse files
committed
More readability
1 parent df08213 commit 3121325

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

src/gfloat/round_ndarray.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def round_ndarray(
2626
is_negative = np.signbit(v) & fi.is_signed
2727
absv = np.where(is_negative, -v, v)
2828

29-
nonzerofinite_mask = ~(np.isnan(v) | np.isinf(v) | (v == 0))
29+
finite_nonzero = ~(np.isnan(v) | np.isinf(v) | (v == 0))
3030

31-
# Place 1.0 where nonzerofinite_mask is False
32-
absv_masked = np.where(nonzerofinite_mask, absv, 1.0)
31+
# Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
32+
absv_masked = np.where(finite_nonzero, absv, 1.0)
3333

3434
expval = np.floor(np.log2(absv_masked)).astype(int)
3535

@@ -69,23 +69,23 @@ def round_ndarray(
6969
expval += round_up & (isignificand == 1)
7070
isignificand = np.where(round_up, 1, isignificand)
7171

72-
result = np.where(nonzerofinite_mask, isignificand * (2.0**expval), absv)
72+
result = np.where(finite_nonzero, isignificand * (2.0**expval), absv)
7373

7474
amax = np.where(is_negative, -fi.min, fi.max)
7575

7676
if sat:
7777
result = np.where(result > amax, amax, result)
7878
else:
7979
if rnd == RoundMode.TowardNegative:
80-
put_amax_at = (result > amax) & nonzerofinite_mask & ~is_negative
80+
put_amax_at = (result > amax) & ~is_negative
8181
elif rnd == RoundMode.TowardPositive:
82-
put_amax_at = (result > amax) & nonzerofinite_mask & is_negative
82+
put_amax_at = (result > amax) & is_negative
8383
elif rnd == RoundMode.TowardZero:
84-
put_amax_at = (result > amax) & nonzerofinite_mask
84+
put_amax_at = result > amax
8585
else:
8686
put_amax_at = np.zeros_like(result, dtype=bool)
8787

88-
result = np.where(put_amax_at, amax, result)
88+
result = np.where(finite_nonzero & put_amax_at, amax, result)
8989

9090
# Now anything larger than amax goes to infinity or NaN
9191
if fi.has_infs:
@@ -119,7 +119,6 @@ def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
119119
vpos = np.where(sign, -v, v)
120120

121121
nan_mask = np.isnan(v)
122-
inf_mask = np.isinf(v)
123122

124123
code = np.zeros_like(v, dtype=np.uint64)
125124

@@ -148,15 +147,12 @@ def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
148147
finite_sign = sign[finite_mask]
149148

150149
sig, exp = np.frexp(finite_vpos)
151-
expl = exp.astype(int) - 1
152-
tsig = sig * 2
153150

154-
biased_exp = expl + fi.expBias
151+
biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
155152
subnormal_mask = (biased_exp < 1) & fi.has_subnormals
156153

157-
tsig[subnormal_mask] *= 2.0 ** (biased_exp[subnormal_mask] - 1)
154+
tsig = np.where(subnormal_mask, sig * 2.0**biased_exp, sig * 2 - 1.0)
158155
biased_exp[subnormal_mask] = 0
159-
tsig[~subnormal_mask] -= 1.0
160156

161157
isig = np.floor(tsig * 2**t).astype(int)
162158

0 commit comments

Comments
 (0)