Skip to content

Commit 0707cb0

Browse files
committed
Mypy
1 parent f763f72 commit 0707cb0

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/gfloat/round_ndarray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3+
from types import ModuleType
34
from .types import FormatInfo, RoundMode
45
import numpy as np
56
import math
@@ -14,7 +15,7 @@ def round_ndarray(
1415
v: np.ndarray,
1516
rnd: RoundMode = RoundMode.TiesToEven,
1617
sat: bool = False,
17-
np=np,
18+
np: ModuleType = np,
1819
) -> np.ndarray:
1920
"""
2021
Vectorized version of round_float.

test/test_round.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3-
from typing import Type, Callable
3+
from typing import Type, Callable, Iterator, Tuple
44

55
import ml_dtypes
66
import numpy as np
@@ -10,11 +10,15 @@
1010
from gfloat.formats import *
1111

1212

13-
def rnd_scalar(fi, v, mode=RoundMode.TiesToEven, sat: bool = False):
13+
def rnd_scalar(
14+
fi: FormatInfo, v: float, mode: RoundMode = RoundMode.TiesToEven, sat: bool = False
15+
) -> float:
1416
return round_float(fi, v, mode, sat)
1517

1618

17-
def rnd_array(fi, v, mode=RoundMode.TiesToEven, sat: bool = False):
19+
def rnd_array(
20+
fi: FormatInfo, v: float, mode: RoundMode = RoundMode.TiesToEven, sat: bool = False
21+
) -> float:
1822
return round_ndarray(fi, np.array([v]), mode, sat).item()
1923

2024

@@ -394,7 +398,7 @@ def test_round(fi: FormatInfo) -> None:
394398
round(v0 + 0.6*dv) == v1
395399
"""
396400

397-
def get_vals():
401+
def get_vals() -> Iterator[Tuple[float, float]]:
398402
for i in some_positive_codepoints:
399403
v0 = decode_float(fi, i + 0).fval
400404
v1 = decode_float(fi, i + 1).fval

0 commit comments

Comments
 (0)