Skip to content

Commit f19efb5

Browse files
committed
mypy
1 parent 28dbe01 commit f19efb5

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/gfloat/block.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dataclasses import dataclass
77
from typing import Iterable, Callable
88
import numpy as np
9+
import numpy.typing as npt
910

1011
from .decode import decode_float
1112
from .round import RoundMode, encode_float, round_float
@@ -122,7 +123,10 @@ def enc(ty: FormatInfo, x: float) -> int:
122123
yield enc(fi.etype, val)
123124

124125

125-
def compute_scale_amax(etype_emax: float, vals: np.array) -> float:
126+
ComputeScaleCallable = Callable[[float, npt.ArrayLike], float]
127+
128+
129+
def compute_scale_amax(etype_emax: float, vals: npt.ArrayLike) -> float:
126130
"""
127131
Compute a scale factor such that :paramref:`vals` can be
128132
quantized to the range [0, 2**etype_emax]
@@ -147,10 +151,10 @@ def compute_scale_amax(etype_emax: float, vals: np.array) -> float:
147151

148152
def quantize_block(
149153
fi: BlockFormatInfo,
150-
vals: np.array,
151-
compute_scale: Callable[[float, np.array], float] = compute_scale_amax,
154+
vals: npt.NDArray[np.float64],
155+
compute_scale: ComputeScaleCallable = compute_scale_amax,
152156
round: RoundMode = RoundMode.TiesToEven,
153-
) -> np.array:
157+
) -> npt.NDArray[np.float64]:
154158
"""
155159
Encode and decode a block of :paramref:`vals` of bytes into block Format descibed by :paramref:`fi`
156160
@@ -169,5 +173,6 @@ def quantize_block(
169173
"""
170174

171175
q_scale = compute_scale_amax(fi.etype.emax, vals)
172-
enc = encode_block(fi, q_scale, vals / q_scale, round)
176+
scaled_vals = vals / q_scale
177+
enc = encode_block(fi, q_scale, scaled_vals, round)
173178
return np.fromiter(decode_block(fi, enc), float)

0 commit comments

Comments
 (0)