Skip to content

Commit e669724

Browse files
committed
Add tests, fix logic near zero, and directed
1 parent f263a7a commit e669724

File tree

2 files changed

+262
-7
lines changed

2 files changed

+262
-7
lines changed

src/gfloat/round.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ def round_float(
5656
if np.isinf(vpos):
5757
result = np.inf
5858

59-
elif fi.has_subnormals and vpos < fi.smallest_subnormal / 2:
60-
# Test against smallest_subnormal to avoid subnormals in frexp below
61-
# Note that this restricts us to types narrower than float64
62-
result = 0.0
59+
elif vpos == 0:
60+
result = 0
6361

6462
else:
6563
# Extract exponent
@@ -119,9 +117,15 @@ def round_float(
119117
return 0.0
120118

121119
# Overflow
122-
if result > (-fi.min if sign else fi.max):
123-
if sat:
124-
result = fi.max
120+
amax = -fi.min if sign else fi.max
121+
if result > amax:
122+
if (
123+
sat
124+
or (rnd == RoundMode.TowardNegative and not sign and np.isfinite(v))
125+
or (rnd == RoundMode.TowardPositive and sign and np.isfinite(v))
126+
or (rnd == RoundMode.TowardZero and np.isfinite(v))
127+
):
128+
result = amax
125129
else:
126130
if fi.has_infs:
127131
result = np.inf

test/test_round.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,257 @@ def test_round_p3109() -> None:
3939
assert round_float(fi, 232.1) == np.inf
4040

4141

42+
p4min = 2**-10 # smallest subnormal in p4
43+
44+
45+
@pytest.mark.parametrize(
46+
"mode, vals",
47+
(
48+
(
49+
RoundMode.TowardZero,
50+
(
51+
(p4min, p4min),
52+
(p4min / 4, 0),
53+
(-p4min, -p4min),
54+
(-p4min / 4, 0.0),
55+
(64.0, 64.0),
56+
(63.0, 60.0),
57+
(62.0, 60.0),
58+
(-64.0, -64.0),
59+
(-63.0, -60.0),
60+
(-62.0, -60.0),
61+
),
62+
),
63+
(
64+
RoundMode.TowardPositive,
65+
(
66+
(p4min, p4min),
67+
(p4min / 4, p4min),
68+
(-p4min, -p4min),
69+
(-p4min / 4, 0.0),
70+
(64.0, 64.0),
71+
(63.0, 64.0),
72+
(62.0, 64.0),
73+
(-64.0, -64.0),
74+
(-63.0, -60.0),
75+
(-62.0, -60.0),
76+
),
77+
),
78+
(
79+
RoundMode.TowardNegative,
80+
(
81+
(p4min, p4min),
82+
(p4min / 4, 0),
83+
(-p4min, -p4min),
84+
(-p4min / 4, -p4min),
85+
(64.0, 64.0),
86+
(63.0, 60.0),
87+
(62.0, 60.0),
88+
(-64.0, -64.0),
89+
(-63.0, -64.0),
90+
(-62.0, -64.0),
91+
),
92+
),
93+
(
94+
RoundMode.TiesToEven,
95+
(
96+
(p4min, p4min),
97+
(p4min / 4, 0),
98+
(p4min / 2, 0),
99+
(-p4min, -p4min),
100+
(-p4min / 4, 0),
101+
(-p4min / 2, 0),
102+
(64.0, 64.0),
103+
(63.0, 64.0),
104+
(62.0, 64.0),
105+
(61.0, 60.0),
106+
(-64.0, -64.0),
107+
(-63.0, -64.0),
108+
(-62.0, -64.0),
109+
(-61.0, -60.0),
110+
(-58.0, -56.0),
111+
),
112+
),
113+
(
114+
RoundMode.TiesToAway,
115+
(
116+
(p4min, p4min),
117+
(p4min / 4, 0),
118+
(p4min / 2, p4min),
119+
(-p4min, -p4min),
120+
(-p4min / 4, 0),
121+
(-p4min / 2, -p4min),
122+
(64.0, 64.0),
123+
(63.0, 64.0),
124+
(62.0, 64.0),
125+
(61.0, 60.0),
126+
(-64.0, -64.0),
127+
(-63.0, -64.0),
128+
(-62.0, -64.0),
129+
(-61.0, -60.0),
130+
(-58.0, -60.0),
131+
),
132+
),
133+
),
134+
)
135+
def test_round_p3109b(mode, vals) -> None:
136+
fi = format_info_p3109(4)
137+
138+
for val, expected in vals:
139+
sat = True
140+
assert round_float(fi, val, mode, sat) == expected
141+
142+
143+
p4max = 224.0
144+
p4maxup = 240.0
145+
p4maxhalfup = (p4max + p4maxup) / 2
146+
147+
148+
@pytest.mark.parametrize(
149+
"modesat, vals",
150+
(
151+
(
152+
(RoundMode.TowardZero, True),
153+
(
154+
(p4max, p4max),
155+
(p4maxhalfup, p4max),
156+
(p4maxup, p4max),
157+
(np.inf, p4max),
158+
(-p4max, -p4max),
159+
(-p4maxhalfup, -p4max),
160+
(-p4maxup, -p4max),
161+
(-np.inf, -p4max),
162+
),
163+
),
164+
(
165+
(RoundMode.TowardZero, False),
166+
(
167+
(p4max, p4max),
168+
(p4maxhalfup, p4max),
169+
(p4maxup, p4max),
170+
(np.inf, np.inf),
171+
(-p4max, -p4max),
172+
(-p4maxhalfup, -p4max),
173+
(-p4maxup, -p4max),
174+
(-np.inf, -np.inf),
175+
),
176+
),
177+
(
178+
(RoundMode.TowardPositive, True),
179+
(
180+
(p4max, p4max),
181+
(p4maxhalfup, p4max),
182+
(p4maxup, p4max),
183+
(np.inf, p4max),
184+
(-p4max, -p4max),
185+
(-p4maxhalfup, -p4max),
186+
(-p4maxup, -p4max),
187+
(-np.inf, -p4max),
188+
),
189+
),
190+
(
191+
(RoundMode.TowardPositive, False),
192+
(
193+
(p4max, p4max),
194+
(p4maxhalfup, np.inf),
195+
(p4maxup, np.inf),
196+
(np.inf, np.inf),
197+
(-p4max, -p4max),
198+
(-p4maxhalfup, -p4max),
199+
(-p4maxup, -p4max),
200+
(-np.inf, -np.inf),
201+
),
202+
),
203+
(
204+
(RoundMode.TowardNegative, True),
205+
(
206+
(p4max, p4max),
207+
(p4maxhalfup, p4max),
208+
(p4maxup, p4max),
209+
(np.inf, p4max),
210+
(-p4max, -p4max),
211+
(-p4maxhalfup, -p4max),
212+
(-p4maxup, -p4max),
213+
(-np.inf, -p4max),
214+
),
215+
),
216+
(
217+
(RoundMode.TowardNegative, False),
218+
(
219+
(p4max, p4max),
220+
(p4maxhalfup, p4max),
221+
(p4maxup, p4max),
222+
(np.inf, np.inf),
223+
(-p4max, -p4max),
224+
(-p4maxhalfup, -np.inf),
225+
(-p4maxup, -np.inf),
226+
(-np.inf, -np.inf),
227+
),
228+
),
229+
(
230+
(RoundMode.TiesToEven, True),
231+
(
232+
(p4max, p4max),
233+
(p4maxhalfup, p4max),
234+
(p4maxup, p4max),
235+
(np.inf, p4max),
236+
(-p4max, -p4max),
237+
(-p4maxhalfup, -p4max),
238+
(-p4maxup, -p4max),
239+
(-np.inf, -p4max),
240+
),
241+
),
242+
(
243+
(RoundMode.TiesToEven, False),
244+
(
245+
(p4max, p4max),
246+
(p4maxhalfup, p4max),
247+
(p4maxup, np.inf),
248+
(np.inf, np.inf),
249+
(-p4max, -p4max),
250+
(-p4maxhalfup, -p4max),
251+
(-p4maxup, -np.inf),
252+
(-np.inf, -np.inf),
253+
),
254+
),
255+
(
256+
(RoundMode.TiesToAway, True),
257+
(
258+
(p4max, p4max),
259+
(p4maxhalfup, p4max),
260+
(p4maxup, p4max),
261+
(np.inf, p4max),
262+
(-p4max, -p4max),
263+
(-p4maxhalfup, -p4max),
264+
(-p4maxup, -p4max),
265+
(-np.inf, -p4max),
266+
),
267+
),
268+
(
269+
(RoundMode.TiesToAway, False),
270+
(
271+
(p4max, p4max),
272+
(p4maxhalfup, np.inf),
273+
(p4maxup, np.inf),
274+
(np.inf, np.inf),
275+
(-p4max, -p4max),
276+
(-p4maxhalfup, -np.inf),
277+
(-p4maxup, -np.inf),
278+
(-np.inf, -np.inf),
279+
),
280+
),
281+
),
282+
ids=lambda x: (
283+
f"{str(x[0])}-{'Sat' if x[1] else 'Inf'}" if len(x) == 2 else f"{len(x)}"
284+
),
285+
)
286+
def test_round_p3109_sat(modesat, vals) -> None:
287+
fi = format_info_p3109(4)
288+
289+
for val, expected in vals:
290+
assert round_float(fi, val, *modesat) == expected
291+
292+
42293
def test_round_e5m2() -> None:
43294
fi = format_info_ocp_e5m2
44295

0 commit comments

Comments
 (0)