|
1 | 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved.
|
2 | 2 |
|
3 |
| -from typing import Type, Callable |
| 3 | +from typing import Type, Callable, Iterator, Tuple |
4 | 4 |
|
5 | 5 | import ml_dtypes
|
6 | 6 | import numpy as np
|
|
10 | 10 | from gfloat.formats import *
|
11 | 11 |
|
12 | 12 |
|
13 |
| -def rnd_scalar(fi, v, mode=RoundMode.TiesToEven, sat: bool = False): |
| 13 | +def rnd_scalar( |
| 14 | + fi: FormatInfo, v: float, mode: RoundMode = RoundMode.TiesToEven, sat: bool = False |
| 15 | +) -> float: |
14 | 16 | return round_float(fi, v, mode, sat)
|
15 | 17 |
|
16 | 18 |
|
17 |
| -def rnd_array(fi, v, mode=RoundMode.TiesToEven, sat: bool = False): |
| 19 | +def rnd_array( |
| 20 | + fi: FormatInfo, v: float, mode: RoundMode = RoundMode.TiesToEven, sat: bool = False |
| 21 | +) -> float: |
18 | 22 | return round_ndarray(fi, np.array([v]), mode, sat).item()
|
19 | 23 |
|
20 | 24 |
|
@@ -394,7 +398,7 @@ def test_round(fi: FormatInfo) -> None:
|
394 | 398 | round(v0 + 0.6*dv) == v1
|
395 | 399 | """
|
396 | 400 |
|
397 |
| - def get_vals(): |
| 401 | + def get_vals() -> Iterator[Tuple[float, float]]: |
398 | 402 | for i in some_positive_codepoints:
|
399 | 403 | v0 = decode_float(fi, i + 0).fval
|
400 | 404 | v1 = decode_float(fi, i + 1).fval
|
|
0 commit comments