Skip to content

Commit 12bedbe

Browse files
committed
mypy
1 parent d8a2b01 commit 12bedbe

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/gfloat/decode_ndarray.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3+
from types import ModuleType
34
import numpy as np
5+
from .types import FormatInfo
46

5-
from .types import FloatClass, FloatValue, FormatInfo
67

7-
8-
def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np=np) -> np.ndarray:
8+
def decode_ndarray(
9+
fi: FormatInfo, codes: np.ndarray, np: ModuleType = np
10+
) -> np.ndarray:
911
r"""
1012
Vectorized version of :function:`decode_float`
1113

test/test_decode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def dec(code: int) -> float:
3434

3535

3636
@pytest.mark.parametrize("method", methods)
37-
def test_spot_check_ocp_e5m2(method) -> None:
37+
def test_spot_check_ocp_e5m2(method: str) -> None:
3838
fi = format_info_ocp_e5m2
3939
dec = get_method(method, fi)
4040
fclass = lambda code: decode_float(fi, code).fclass
@@ -51,7 +51,7 @@ def test_spot_check_ocp_e5m2(method) -> None:
5151

5252

5353
@pytest.mark.parametrize("method", methods)
54-
def test_spot_check_ocp_e4m3(method) -> None:
54+
def test_spot_check_ocp_e4m3(method: str) -> None:
5555
fi = format_info_ocp_e4m3
5656
dec = get_method(method, fi)
5757
assert dec(0x40) == 2.0

0 commit comments

Comments
 (0)