|
6 | 6 | import numpy as np
|
7 | 7 | import pytest
|
8 | 8 |
|
9 |
| -from gfloat import RoundMode, decode_float, round_float, round_ndarray |
| 9 | +from gfloat import RoundMode, decode_float, decode_ndarray, round_float, round_ndarray |
10 | 10 | from gfloat.formats import *
|
11 | 11 |
|
12 | 12 |
|
@@ -428,7 +428,7 @@ def get_vals() -> Iterator[Tuple[float, float]]:
|
428 | 428 | ]
|
429 | 429 |
|
430 | 430 |
|
431 |
| -def _linterp(a: float, b: float, t: float) -> float: |
| 431 | +def _linterp(a, b, t): # type: ignore[no-untyped-def] |
432 | 432 | return a * (1 - t) + b * t
|
433 | 433 |
|
434 | 434 |
|
@@ -539,3 +539,43 @@ def test_stochastic_rounding(
|
539 | 539 | # this is loose, but should still catch logic errors
|
540 | 540 | atol = n * 2.0 ** (-1 - srnumbits)
|
541 | 541 | np.testing.assert_allclose(count_v1, expected_up_count, atol=atol)
|
| 542 | + |
| 543 | + |
| 544 | +def test_stochastic_rounding_2() -> None: |
| 545 | + fi = format_info_p3109(3) |
| 546 | + |
| 547 | + v0 = decode_ndarray(fi, np.arange(255)) |
| 548 | + v1 = decode_ndarray(fi, np.arange(255) + 1) |
| 549 | + ok = np.isfinite(v0) & np.isfinite(v1) |
| 550 | + v0 = v0[ok] |
| 551 | + v1 = v1[ok] |
| 552 | + |
| 553 | + srnumbits = 3 |
| 554 | + for srbits in range(2**srnumbits): |
| 555 | + for alpha in (0, 0.3, 0.5, 0.6, 0.9, 1.25): |
| 556 | + v = _linterp(v0, v1, alpha) |
| 557 | + assert np.isfinite(v).all() |
| 558 | + val_array = round_ndarray( |
| 559 | + fi, |
| 560 | + v, |
| 561 | + RoundMode.Stochastic, |
| 562 | + sat=False, |
| 563 | + srbits=np.asarray(srbits), |
| 564 | + srnumbits=srnumbits, |
| 565 | + ) |
| 566 | + |
| 567 | + val_scalar = [ |
| 568 | + round_float( |
| 569 | + fi, |
| 570 | + v, |
| 571 | + RoundMode.Stochastic, |
| 572 | + sat=False, |
| 573 | + srbits=srbits, |
| 574 | + srnumbits=srnumbits, |
| 575 | + ) |
| 576 | + for v in v |
| 577 | + ] |
| 578 | + if alpha < 1.0: |
| 579 | + assert ((val_array == v0) | (val_array == v1)).all() |
| 580 | + |
| 581 | + np.testing.assert_equal(val_array, val_scalar) |
0 commit comments