Skip to content

Commit 9625ca1

Browse files
committed
Add SR on arrays, add visualization on notebook
1 parent 03b3324 commit 9625ca1

File tree

3 files changed

+105
-42
lines changed

3 files changed

+105
-42
lines changed

docs/source/05-stochastic-rounding.ipynb

Lines changed: 61 additions & 16 deletions
Large diffs are not rendered by default.

src/gfloat/round_ndarray.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3+
from typing import Optional
34
from types import ModuleType
45
from .types import FormatInfo, RoundMode
56
import numpy as np
@@ -15,6 +16,8 @@ def round_ndarray(
1516
v: np.ndarray,
1617
rnd: RoundMode = RoundMode.TiesToEven,
1718
sat: bool = False,
19+
srbits: Optional[np.ndarray] = None,
20+
srnumbits: int = 0,
1821
np: ModuleType = np,
1922
) -> np.ndarray:
2023
"""
@@ -70,18 +73,24 @@ def round_ndarray(
7073
else:
7174
code_is_odd = (isignificand != 0) & _isodd(expval + bias)
7275

76+
if rnd == RoundMode.TowardZero:
77+
should_round_away = np.zeros_like(delta, dtype=bool)
7378
if rnd == RoundMode.TowardPositive:
74-
round_up = ~is_negative & (delta > 0)
75-
elif rnd == RoundMode.TowardNegative:
76-
round_up = is_negative & (delta > 0)
77-
elif rnd == RoundMode.TiesToAway:
78-
round_up = delta >= 0.5
79-
elif rnd == RoundMode.TiesToEven:
80-
round_up = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
81-
else:
82-
round_up = np.zeros_like(delta, dtype=bool)
83-
84-
isignificand = np.where(round_up, isignificand + 1, isignificand)
79+
should_round_away = ~is_negative & (delta > 0)
80+
if rnd == RoundMode.TowardNegative:
81+
should_round_away = is_negative & (delta > 0)
82+
if rnd == RoundMode.TiesToAway:
83+
should_round_away = delta >= 0.5
84+
if rnd == RoundMode.TiesToEven:
85+
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
86+
if rnd == RoundMode.Stochastic:
87+
assert srbits is not None
88+
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
89+
if rnd == RoundMode.StochasticFast:
90+
assert srbits is not None
91+
should_round_away = delta > srbits * 2.0**-srnumbits
92+
93+
isignificand = np.where(should_round_away, isignificand + 1, isignificand)
8594

8695
result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv)
8796

test/test_round.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,10 @@ def test_round_roundtrip(round_float: Callable, fi: FormatInfo) -> None:
500500
(311, 3, 7.0 / 8),
501501
),
502502
)
503-
def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> None:
503+
@pytest.mark.parametrize("impl", ("scalar", "array"))
504+
def test_stochastic_rounding(
505+
impl: bool, v: float, srnumbits: int, expected_up: float
506+
) -> None:
504507
fi = format_info_ocp_e5m2
505508

506509
v0 = round_float(fi, v, RoundMode.TowardNegative)
@@ -510,20 +513,26 @@ def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> No
510513
expected_up_count = expected_up * n
511514

512515
srbits = np.random.randint(0, 2**srnumbits, size=(n,))
513-
count_v1 = 0
514-
for k in range(n):
515-
r = round_float(
516-
fi,
517-
v,
518-
RoundMode.Stochastic,
519-
sat=False,
520-
srbits=srbits[k],
521-
srnumbits=srnumbits,
522-
)
523-
if r == v1:
524-
count_v1 += 1
525-
else:
526-
assert r == v0
516+
if impl == "scalar":
517+
count_v1 = 0
518+
for k in range(n):
519+
r = round_float(
520+
fi,
521+
v,
522+
RoundMode.Stochastic,
523+
sat=False,
524+
srbits=srbits[k],
525+
srnumbits=srnumbits,
526+
)
527+
if r == v1:
528+
count_v1 += 1
529+
else:
530+
assert r == v0
531+
else:
532+
vs = np.full(n, v)
533+
rs = round_ndarray(fi, vs, RoundMode.Stochastic, False, srbits, srnumbits)
534+
assert np.all((rs == v0) | (rs == v1))
535+
count_v1 = np.sum(rs == v1)
527536

528537
print(f"SRBits={srnumbits}, observed = {count_v1}, expected = {expected_up_count} ")
529538
# e.g. if expected is 1250/10000, want to be within 0.5,1.5

0 commit comments

Comments
 (0)