Skip to content

Commit 5b9cf32

Browse files
committed
add test
1 parent 9625ca1 commit 5b9cf32

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

test/test_round.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pytest
88

9-
from gfloat import RoundMode, decode_float, round_float, round_ndarray
9+
from gfloat import RoundMode, decode_float, decode_ndarray, round_float, round_ndarray
1010
from gfloat.formats import *
1111

1212

@@ -428,7 +428,7 @@ def get_vals() -> Iterator[Tuple[float, float]]:
428428
]
429429

430430

431-
def _linterp(a: float, b: float, t: float) -> float:
431+
def _linterp(a, b, t): # type: ignore[no-untyped-def]
432432
return a * (1 - t) + b * t
433433

434434

@@ -539,3 +539,43 @@ def test_stochastic_rounding(
539539
# this is loose, but should still catch logic errors
540540
atol = n * 2.0 ** (-1 - srnumbits)
541541
np.testing.assert_allclose(count_v1, expected_up_count, atol=atol)
542+
543+
544+
def test_stochastic_rounding_2() -> None:
545+
fi = format_info_p3109(3)
546+
547+
v0 = decode_ndarray(fi, np.arange(255))
548+
v1 = decode_ndarray(fi, np.arange(255) + 1)
549+
ok = np.isfinite(v0) & np.isfinite(v1)
550+
v0 = v0[ok]
551+
v1 = v1[ok]
552+
553+
srnumbits = 3
554+
for srbits in range(2**srnumbits):
555+
for alpha in (0, 0.3, 0.5, 0.6, 0.9, 1.25):
556+
v = _linterp(v0, v1, alpha)
557+
assert np.isfinite(v).all()
558+
val_array = round_ndarray(
559+
fi,
560+
v,
561+
RoundMode.Stochastic,
562+
sat=False,
563+
srbits=np.asarray(srbits),
564+
srnumbits=srnumbits,
565+
)
566+
567+
val_scalar = [
568+
round_float(
569+
fi,
570+
v,
571+
RoundMode.Stochastic,
572+
sat=False,
573+
srbits=srbits,
574+
srnumbits=srnumbits,
575+
)
576+
for v in v
577+
]
578+
if alpha < 1.0:
579+
assert ((val_array == v0) | (val_array == v1)).all()
580+
581+
np.testing.assert_equal(val_array, val_scalar)

0 commit comments

Comments
 (0)