4
4
# https://en.wikipedia.org/wiki/Block_floating_point
5
5
6
6
from dataclasses import dataclass
7
- from typing import Iterable
7
+ from typing import Iterable , Callable
8
+ import numpy as np
8
9
9
10
from .decode import decode_float
10
- from .round import encode_float , round_float , RoundMode
11
+ from .round import RoundMode , encode_float , round_float
11
12
from .types import FormatInfo
12
13
13
14
@@ -84,10 +85,12 @@ def encode_block(
84
85
round : RoundMode = RoundMode .TiesToEven ,
85
86
) -> Iterable [int ]:
86
87
"""
87
- Encode a :paramref:`block` of bytes into block Format descibed by :paramref:`fi`
88
+ Encode float :paramref:`vals` into block Format descibed by :paramref:`fi`
88
89
89
- The :paramref:`scale` is explicitly passed, and is converted to `1/(1/scale)`
90
- before rounding to the target format.
90
+ The :paramref:`scale` is explicitly passed, and the :paramref:`vals` are
91
+ assumed to already be multiplied by `1/scale`.
92
+ That is, this is pure encoding, scaling is computed and applied elsewhere
93
+ (see e.g. :funcref:`quantize_block`).
91
94
92
95
It is checked for overflow in the target format,
93
96
and will raise an exception if it does.
@@ -105,11 +108,6 @@ def encode_block(
105
108
ValueError: The scale overflows the target scale encoding format.
106
109
"""
107
110
108
- # TODO: this should really not do any multiplication -
109
- # the scale is to be recorded not applied.
110
- recip_scale = 1 / scale
111
- scale = 1 / recip_scale
112
-
113
111
if scale > fi .stype .max :
114
112
raise ValueError (f"Scaled { scale } too large for { fi .stype } " )
115
113
@@ -121,4 +119,55 @@ def enc(ty: FormatInfo, x: float) -> int:
121
119
yield enc (fi .stype , scale )
122
120
123
121
for val in vals :
124
- yield enc (fi .etype , recip_scale * val )
122
+ yield enc (fi .etype , val )
123
+
124
+
125
def compute_scale_amax(etype_emax: float, vals: np.ndarray) -> float:
    """
    Compute a power-of-two scale factor from the absolute maximum of :paramref:`vals`.

    The result is `2 ** (floor(log2(amax)) - etype_emax)`, where `amax` is the
    largest absolute value in :paramref:`vals`.  The scale is a divisor:
    after computing `vals / scale`, the largest-magnitude element has
    exponent `etype_emax`.

    Args:
      etype_emax (float): Maximum exponent to appear in `vals / scale`
      vals (numpy.ndarray): Input block

    Returns:
      A float such that `vals / scale` has exponents less than or equal to `etype_emax`.

    Note:
      If all vals are zero, 1.0 is returned.
    """
    amax = np.max(np.abs(vals))
    if amax == 0.0:
        # log2(0) is -inf, so an all-zero block is special-cased;
        # 1.0 is a good scale value
        return 1.0

    # Exponent of the largest element, shifted so it lands on etype_emax
    # once the block has been divided by the returned scale.
    q_log2scale = np.floor(np.log2(amax)) - etype_emax
    return 2.0**q_log2scale
146
+
147
+
148
def quantize_block(
    fi: BlockFormatInfo,
    vals: np.ndarray,
    compute_scale: Callable[[float, np.ndarray], float] = compute_scale_amax,
    round: RoundMode = RoundMode.TiesToEven,
) -> np.ndarray:
    """
    Quantize a block of float :paramref:`vals` by encoding it into the block
    format described by :paramref:`fi` and decoding it again.

    Args:
      fi (BlockFormatInfo): Describes the target block format
      vals (numpy.ndarray): Input block
      compute_scale ((float, np.ndarray) -> float):
          Callable to compute the scale, invoked as
          `compute_scale(fi.etype.emax, vals)`
      round (RoundMode): Rounding mode to use, defaults to `TiesToEven`

    Returns:
      An array of floats representing the quantized values.

    Raises:
      ValueError: The scale overflows the target scale encoding format.
    """
    # Honour the caller-supplied scaling strategy rather than always
    # falling back to the amax-based default.
    q_scale = compute_scale(fi.etype.emax, vals)

    # encode_block records the scale but does not apply it, so the values
    # are divided by the scale here before being encoded.
    enc = encode_block(fi, q_scale, vals / q_scale, round)
    return np.fromiter(decode_block(fi, enc), float)
0 commit comments