6
6
from dataclasses import dataclass
7
7
from typing import Iterable , Callable
8
8
import numpy as np
9
+ import numpy .typing as npt
9
10
10
11
from .decode import decode_float
11
12
from .round import RoundMode , encode_float , round_float
@@ -122,7 +123,10 @@ def enc(ty: FormatInfo, x: float) -> int:
122
123
yield enc (fi .etype , val )
123
124
124
125
125
- def compute_scale_amax (etype_emax : float , vals : np .array ) -> float :
126
+ ComputeScaleCallable = Callable [[float , npt .ArrayLike ], float ]
127
+
128
+
129
+ def compute_scale_amax (etype_emax : float , vals : npt .ArrayLike ) -> float :
126
130
"""
127
131
Compute a scale factor such that :paramref:`vals` can be
128
132
quantized to the range [0, 2**etype_emax]
@@ -147,10 +151,10 @@ def compute_scale_amax(etype_emax: float, vals: np.array) -> float:
147
151
148
152
def quantize_block (
149
153
fi : BlockFormatInfo ,
150
- vals : np .array ,
151
- compute_scale : Callable [[ float , np . array ], float ] = compute_scale_amax ,
154
+ vals : npt . NDArray [ np .float64 ] ,
155
+ compute_scale : ComputeScaleCallable = compute_scale_amax ,
152
156
round : RoundMode = RoundMode .TiesToEven ,
153
- ) -> np .array :
157
+ ) -> npt . NDArray [ np .float64 ] :
154
158
"""
155
159
Encode and decode a block of :paramref:`vals` of bytes into block Format descibed by :paramref:`fi`
156
160
@@ -169,5 +173,6 @@ def quantize_block(
169
173
"""
170
174
171
175
q_scale = compute_scale_amax (fi .etype .emax , vals )
172
- enc = encode_block (fi , q_scale , vals / q_scale , round )
176
+ scaled_vals = vals / q_scale
177
+ enc = encode_block (fi , q_scale , scaled_vals , round )
173
178
return np .fromiter (decode_block (fi , enc ), float )
0 commit comments