Skip to content

Commit 30fc8de

Browse files
committed
Add quantize_block
1 parent 9a48d99 commit 30fc8de

File tree

4 files changed

+76
-37
lines changed

4 files changed

+76
-37
lines changed

src/gfloat/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3-
from .block import BlockFormatInfo, decode_block, encode_block
3+
from .block import (
4+
BlockFormatInfo,
5+
decode_block,
6+
encode_block,
7+
quantize_block,
8+
compute_scale_amax,
9+
)
410
from .decode import decode_float
511
from .printing import float_pow2str, float_tilde_unless_roundtrip_str
612
from .round import encode_float, round_float

src/gfloat/block.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
# https://en.wikipedia.org/wiki/Block_floating_point
55

66
from dataclasses import dataclass
7-
from typing import Iterable
7+
from typing import Iterable, Callable
8+
import numpy as np
89

910
from .decode import decode_float
10-
from .round import encode_float, round_float, RoundMode
11+
from .round import RoundMode, encode_float, round_float
1112
from .types import FormatInfo
1213

1314

@@ -84,10 +85,12 @@ def encode_block(
8485
round: RoundMode = RoundMode.TiesToEven,
8586
) -> Iterable[int]:
8687
"""
87-
Encode a :paramref:`block` of bytes into block Format descibed by :paramref:`fi`
88+
Encode float :paramref:`vals` into block Format descibed by :paramref:`fi`
8889
89-
The :paramref:`scale` is explicitly passed, and is converted to `1/(1/scale)`
90-
before rounding to the target format.
90+
The :paramref:`scale` is explicitly passed, and the :paramref:`vals` are
91+
assumed to already be multiplied by `1/scale`.
92+
That is, this is pure encoding, scaling is computed and applied elsewhere
93+
(see e.g. :funcref:`quantize_block`).
9194
9295
It is checked for overflow in the target format,
9396
and will raise an exception if it does.
@@ -105,11 +108,6 @@ def encode_block(
105108
ValueError: The scale overflows the target scale encoding format.
106109
"""
107110

108-
# TODO: this should really not do any multiplication -
109-
# the scale is to be recorded not applied.
110-
recip_scale = 1 / scale
111-
scale = 1 / recip_scale
112-
113111
if scale > fi.stype.max:
114112
raise ValueError(f"Scaled {scale} too large for {fi.stype}")
115113

@@ -121,4 +119,55 @@ def enc(ty: FormatInfo, x: float) -> int:
121119
yield enc(fi.stype, scale)
122120

123121
for val in vals:
124-
yield enc(fi.etype, recip_scale * val)
122+
yield enc(fi.etype, val)
123+
124+
125+
def compute_scale_amax(etype_emax: float, vals: np.array) -> float:
126+
"""
127+
Compute a scale factor such that :paramref:`vals` can be
128+
quantized to the range [0, 2**etype_emax]
129+
130+
Args:
131+
etype_emax (float): Maximum exponent to appear in `vals * scale`
132+
vals (numpy.array): Input block
133+
134+
Returns:
135+
A float such that `vals * scale` has exponents less than or equal to `etype_emax`.
136+
137+
Note:
138+
If all vals are zero, 1.0 is returned.
139+
"""
140+
amax = np.max(np.abs(vals))
141+
if amax == 0.0:
142+
# Array is all zeros - 1.0 is a good scale value
143+
return 1.0
144+
q_log2scale = np.floor(np.log2(amax)) - etype_emax
145+
return 2.0**q_log2scale
146+
147+
148+
def quantize_block(
149+
fi: BlockFormatInfo,
150+
vals: np.array,
151+
compute_scale: Callable[[float, np.array], float] = compute_scale_amax,
152+
round: RoundMode = RoundMode.TiesToEven,
153+
) -> np.array:
154+
"""
155+
Encode and decode a block of :paramref:`vals` of bytes into block Format descibed by :paramref:`fi`
156+
157+
Args:
158+
fi (BlockFormatInfo): Describes the target block format
159+
vals (numpy.array): Input block
160+
compute_scale ((float, np.array) -> float):
161+
Callable to compute the scale
162+
round (RoundMode): Rounding mode to use, defaults to `TiesToEven`
163+
164+
Returns:
165+
An array of floats representing the quantized values.
166+
167+
Raises:
168+
ValueError: The scale overflows the target scale encoding format.
169+
"""
170+
171+
q_scale = compute_scale_amax(fi.etype.emax, vals)
172+
enc = encode_block(fi, q_scale, vals / q_scale, round)
173+
return np.fromiter(decode_block(fi, enc), float)

test/test_block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_blocks(fi: BlockFormatInfo) -> None:
1313
vals = np.linspace(-37.0, 42.0, 32)
1414

1515
scale = 8.0
16-
block = list(encode_block(fi, scale, vals))
16+
block = list(encode_block(fi, scale, vals / scale))
1717
decoded_vals = list(decode_block(fi, block))
1818

1919
atol = 2 * scale * fi.etype.eps

test/test_microxcaling.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,7 @@
1010
from mx.formats import ElemFormat
1111

1212

13-
from gfloat import (
14-
BlockFormatInfo,
15-
encode_block,
16-
decode_block,
17-
encode_float,
18-
decode_float,
19-
round_float,
20-
RoundMode,
21-
)
13+
from gfloat import BlockFormatInfo, RoundMode, quantize_block, compute_scale_amax
2214
from gfloat.formats import *
2315

2416

@@ -41,11 +33,10 @@ def test_mx(
4133
mx_etype: ElemFormat,
4234
gf_etype: FormatInfo,
4335
) -> None:
44-
## Input tensor
36+
# Input tensor
4537
A = np.arange(32) / 2 - 5
4638

47-
## Compute MX quantization
48-
# Declare block format
39+
# MX: Declare block format
4940
mx_specs = dict(
5041
block_size=32,
5142
scale_bits=8,
@@ -54,21 +45,14 @@ def test_mx(
5445
custom_cuda=False,
5546
)
5647

57-
# Compute scale, encode, decode
48+
# MX: Quantize
5849
mx_dq = quantize_mx_op(torch.tensor(A), mx_specs, mx_etype, axes=0, round=mx_round)
5950

60-
## Compute GFloat quantization
61-
# Declare block format
51+
# GFloat: Declare block format
6252
fi = BlockFormatInfo("test", gf_etype, 32, format_info_ocp_e8m0)
6353

64-
# Compute scale - this is not considered GFloat's job, but could easily be added
65-
amax = np.max(np.abs(A))
66-
q_log2scale = np.floor(np.log2(amax)) - fi.etype.emax
67-
q_scale = 2**q_log2scale
68-
69-
# Apply scale to encode and decode
70-
enc = encode_block(fi, q_scale, A, gf_round)
71-
gf_dq = list(decode_block(fi, enc))
54+
# GFloat: Quantize
55+
gf_dq = quantize_block(fi, A, compute_scale_amax, gf_round)
7256

73-
## Compare
57+
# Compare
7458
np.testing.assert_allclose(gf_dq, mx_dq)

0 commit comments

Comments
 (0)