Skip to content

Commit aea875d

Browse files
authored
Merge pull request #24 from graphcore-research/simplify-round
Simplify round, fix directed rounding
2 parents b82ad80 + 1ecf411 commit aea875d

File tree

3 files changed: +315 additions, -57 deletions

src/gfloat/round.py

Lines changed: 50 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,6 @@ def round_float(
4040
# Constants
4141
p = fi.precision
4242
bias = fi.expBias
43-
t = p - 1
4443

4544
if np.isnan(v):
4645
if fi.num_nans == 0:
@@ -56,59 +55,57 @@ def round_float(
5655
if np.isinf(vpos):
5756
result = np.inf
5857

59-
elif fi.has_subnormals and vpos < fi.smallest_subnormal / 2:
60-
# Test against smallest_subnormal to avoid subnormals in frexp below
61-
# Note that this restricts us to types narrower than float64
62-
result = 0.0
58+
elif vpos == 0:
59+
result = 0
6360

6461
else:
65-
# Extract significand (mantissa) and exponent
66-
fsignificand, expval = np.frexp(vpos)
67-
assert fsignificand >= 0.5 and fsignificand < 1.0
68-
# Bring significand into range [1.0, 2.0)
69-
fsignificand *= 2
70-
expval -= 1
62+
# Extract exponent
63+
expval = int(np.floor(np.log2(vpos)))
64+
65+
assert expval > -1024 + p # not yet tested for float64 near-subnormals
7166

7267
# Effective precision, accounting for right shift for subnormal values
73-
biased_exp = expval + bias
7468
if fi.has_subnormals:
75-
effective_precision = t + min(biased_exp - 1, 0)
76-
else:
77-
effective_precision = t
69+
expval = max(expval, 1 - bias)
7870

7971
# Lift to "integer * 2^e"
80-
fsignificand *= 2.0**effective_precision
81-
expval -= effective_precision
72+
expval = expval - p + 1
73+
74+
fsignificand = vpos * 2.0**-expval
8275

8376
# Round
8477
isignificand = math.floor(fsignificand)
85-
if isignificand != fsignificand:
86-
# Need to round
87-
if rnd == RoundMode.TowardZero:
88-
pass
89-
elif rnd == RoundMode.TowardPositive:
90-
isignificand += 1 if not sign else 0
91-
elif rnd == RoundMode.TowardNegative:
92-
isignificand += 1 if sign else 0
93-
else:
94-
# Round to nearest
95-
d = fsignificand - isignificand
96-
if d > 0.5:
97-
isignificand += 1
98-
elif d == 0.5:
99-
# Tie
100-
if rnd == RoundMode.TiesToAway:
101-
isignificand += 1
102-
else:
103-
# All other modes tie to even
104-
if fi.precision == 1:
105-
# No significand bits
106-
assert (isignificand == 1) or (isignificand == 0)
107-
if _isodd(biased_exp):
108-
expval += 1
109-
else:
110-
if _isodd(isignificand):
111-
isignificand += 1
78+
delta = fsignificand - isignificand
79+
if (
80+
(rnd == RoundMode.TowardPositive and not sign and delta > 0)
81+
or (rnd == RoundMode.TowardNegative and sign and delta > 0)
82+
or (rnd == RoundMode.TiesToAway and delta >= 0.5)
83+
or (rnd == RoundMode.TiesToEven and delta > 0.5)
84+
or (rnd == RoundMode.TiesToEven and delta == 0.5 and _isodd(isignificand))
85+
):
86+
isignificand += 1
87+
88+
## Special case for Precision=1, all-log format with zero.
89+
if fi.precision == 1:
90+
# The logic is simply duplicated for clarity of reading.
91+
isignificand = math.floor(fsignificand)
92+
code_is_odd = isignificand != 0 and _isodd(expval + bias)
93+
if (
94+
(rnd == RoundMode.TowardPositive and not sign and delta > 0)
95+
or (rnd == RoundMode.TowardNegative and sign and delta > 0)
96+
or (rnd == RoundMode.TiesToAway and delta >= 0.5)
97+
or (rnd == RoundMode.TiesToEven and delta > 0.5)
98+
or (rnd == RoundMode.TiesToEven and delta == 0.5 and code_is_odd)
99+
):
100+
# Go to nextUp.
101+
# Increment isignificand if zero,
102+
# else increment exponent
103+
if isignificand == 0:
104+
isignificand = 1
105+
else:
106+
assert isignificand == 1
107+
expval += 1
108+
## End special case for Precision=1.
112109

113110
result = isignificand * (2.0**expval)
114111

@@ -119,9 +116,15 @@ def round_float(
119116
return 0.0
120117

121118
# Overflow
122-
if result > (-fi.min if sign else fi.max):
123-
if sat:
124-
result = fi.max
119+
amax = -fi.min if sign else fi.max
120+
if result > amax:
121+
if (
122+
sat
123+
or (rnd == RoundMode.TowardNegative and not sign and np.isfinite(v))
124+
or (rnd == RoundMode.TowardPositive and sign and np.isfinite(v))
125+
or (rnd == RoundMode.TowardZero and np.isfinite(v))
126+
):
127+
result = amax
125128
else:
126129
if fi.has_infs:
127130
result = np.inf

src/gfloat/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def bits(self) -> int:
146146
@property
147147
def eps(self) -> float:
148148
"""
149-
The difference between 1.0 and the next smallest representable float
149+
The difference between 1.0 and the smallest representable float
150150
larger than 1.0. For example, for 64-bit binary floats in the IEEE-754
151151
standard, ``eps = 2**-52``, approximately 2.22e-16.
152152
"""
@@ -156,7 +156,7 @@ def eps(self) -> float:
156156
@property
157157
def epsneg(self) -> float:
158158
"""
159-
The difference between 1.0 and the next smallest representable float
159+
The difference between 1.0 and the largest representable float
160160
less than 1.0. For example, for 64-bit binary floats in the IEEE-754
161161
standard, ``epsneg = 2**-53``, approximately 1.11e-16.
162162
"""

0 commit comments

Comments
 (0)