Skip to content

Commit 0ed495c

Browse files
authored
Merge pull request #14 from graphcore-research/text-microxscaling
Add microxscaling tests, Add `quantize_block`.
2 parents 3348d2a + 3308197 commit 0ed495c

File tree

14 files changed

+266
-42
lines changed

14 files changed

+266
-42
lines changed

docs/source/api.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,22 @@ API
55

66
.. module:: gfloat
77

8-
Functions
9-
---------
8+
Scalar Functions
9+
----------------
1010

1111
.. autofunction:: decode_float
1212
.. autofunction:: round_float
1313
.. autofunction:: encode_float
1414

15+
Block format functions
16+
----------------------
17+
1518
.. autofunction:: decode_block
1619
.. autofunction:: encode_block
20+
.. autofunction:: quantize_block
21+
22+
.. autofunction:: compute_scale_amax
23+
1724

1825
Classes
1926
-------

docs/source/conf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
"myst_nb",
2525
]
2626

27+
autodoc_typehints = "none" # We have them in the parameter descriptors
28+
autodoc_typehints_format = "short"
29+
python_use_unqualified_type_names = True
30+
31+
autodoc_type_aliases = {
32+
"Iterable": "Iterable",
33+
"npt.ArrayLike": "ArrayLike",
34+
"npt.NDArray": "NDArray",
35+
}
36+
2737
autodoc_default_options = {
2838
"member-order": "bysource",
2939
}

docs/source/formats.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,13 @@ IEEE WG P3109 Formats
3131
---------------------
3232

3333
.. autofunction:: format_info_p3109
34+
35+
Block Formats
36+
---------------------
37+
38+
.. autodata:: format_info_mxfp8_e5m2
39+
.. autodata:: format_info_mxfp8_e4m3
40+
.. autodata:: format_info_mxfp6_e3m2
41+
.. autodata:: format_info_mxfp6_e2m3
42+
.. autodata:: format_info_mxfp4_e2m1
43+
.. autodata:: format_info_mxint8

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,8 @@ optional-dependencies = {dev = {file = ["requirements-dev.txt"]}}
3333
[tool.black]
3434
line-length = 88
3535
fast = true
36+
37+
[tool.mypy]
38+
[[tool.mypy.overrides]]
39+
module = "mx.*"
40+
ignore_missing_imports = true

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Requirements for tests
22
pytest
33
ml_dtypes
4+
mx @ git+https://github.com/microsoft/microxcaling
45

56
# Requirements for development
67
pre-commit
78
black
89
mypy
910
black[jupyter]
11+
isort
1012

1113
# Requirements for docs
1214
sphinx==7.1.2

src/gfloat/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3-
from .block import BlockFormatInfo, decode_block, encode_block
3+
from .block import (
4+
BlockFormatInfo,
5+
compute_scale_amax,
6+
decode_block,
7+
encode_block,
8+
quantize_block,
9+
)
410
from .decode import decode_float
511
from .printing import float_pow2str, float_tilde_unless_roundtrip_str
612
from .round import encode_float, round_float

src/gfloat/block.py

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
# https://en.wikipedia.org/wiki/Block_floating_point
55

66
from dataclasses import dataclass
7-
from typing import Iterable
7+
from typing import Callable, Iterable
8+
9+
import numpy as np
10+
import numpy.typing as npt
811

912
from .decode import decode_float
10-
from .round import encode_float, round_float
13+
from .round import RoundMode, encode_float, round_float
1114
from .types import FormatInfo
1215

1316

@@ -45,8 +48,12 @@ def block_size_bytes(self) -> int:
4548
assert bits % 8 == 0
4649
return bits // 8
4750

51+
@property
52+
def __name__(self) -> str:
53+
return self.name
54+
4855
def __str__(self) -> str:
49-
return f"{self.name}"
56+
return f"BlockFormatInfo:{self.name})"
5057

5158

5259
def decode_block(fi: BlockFormatInfo, block: Iterable[int]) -> Iterable[float]:
@@ -78,13 +85,18 @@ def decode_block(fi: BlockFormatInfo, block: Iterable[int]) -> Iterable[float]:
7885

7986

8087
def encode_block(
81-
fi: BlockFormatInfo, scale: float, vals: Iterable[float]
88+
fi: BlockFormatInfo,
89+
scale: float,
90+
vals: Iterable[float],
91+
round: RoundMode = RoundMode.TiesToEven,
8292
) -> Iterable[int]:
8393
"""
84-
Encode a :paramref:`block` of bytes into block Format descibed by :paramref:`fi`
94+
Encode float :paramref:`vals` into block Format described by :paramref:`fi`
8595
86-
The :paramref:`scale` is explicitly passed, and is converted to `1/(1/scale)`
87-
before rounding to the target format.
96+
The :paramref:`scale` is explicitly passed, and the :paramref:`vals` are
97+
assumed to already be multiplied by `1/scale`.
98+
That is, this is pure encoding, scaling is computed and applied elsewhere
99+
(see e.g. :func:`quantize_block`).
88100
89101
It is checked for overflow in the target format,
90102
and will raise an exception if it does.
@@ -93,24 +105,88 @@ def encode_block(
93105
fi (BlockFormatInfo): Describes the target block format
94106
scale (float): Scale to be recorded in the block
95107
vals (Iterable[float]): Input block
108+
round (RoundMode): Rounding mode to use, defaults to `TiesToEven`
96109
97110
Returns:
98111
A sequence of ints representing the encoded values.
99112
100113
Raises:
101114
ValueError: The scale overflows the target scale encoding format.
102115
"""
103-
# TODO: this should not do any multiplication - the scale is to be recorded not applied.
104-
recip_scale = 1 / scale
105-
scale = 1 / recip_scale
106116

107-
if scale > fi.stype.max:
108-
raise ValueError(f"Scaled {scale} too large for {fi.stype}")
117+
if scale > fi.stype.max or scale < fi.stype.min:
118+
raise ValueError(f"Scaled {scale} out of range for {fi.stype}")
119+
120+
sat = True # Saturate elements if out of range
109121

110122
def enc(ty: FormatInfo, x: float) -> int:
111-
return encode_float(ty, round_float(ty, x))
123+
return encode_float(ty, round_float(ty, x, round, sat))
112124

113125
yield enc(fi.stype, scale)
114126

115127
for val in vals:
116-
yield enc(fi.etype, recip_scale * val)
128+
yield enc(fi.etype, val)
129+
130+
131+
ComputeScaleCallable = Callable[[float, npt.ArrayLike], float]
132+
133+
134+
def compute_scale_amax(emax: float, vals: npt.ArrayLike) -> float:
135+
"""
136+
Compute a scale factor such that :paramref:`vals` can be scaled to the
137+
range [0, 2**emax]. That is, `scale` is computed such that the largest
138+
exponent in the array `vals / scale` will be `emax`.
139+
140+
The scale is clipped to the range 2**[-127, 127].
141+
142+
If all values are zero, any scale value smaller than emax would be accurate,
143+
but returning the smallest possible means that quick checks on the magnitude
144+
to identify near-zero blocks will also find the all-zero blocks.
145+
146+
Args:
147+
emax (float): Maximum exponent to appear in `vals / scale`
148+
vals (ArrayLike): Input block
149+
150+
Returns:
151+
A float such that `vals / scale` has exponents less than or equal to `emax`.
152+
153+
Note:
154+
If all vals are zero, the smallest scale, 2**-127, is returned.
155+
"""
156+
amax = np.max(np.abs(vals))
157+
if amax == 0.0:
158+
q_log2scale = -127.0
159+
else:
160+
q_log2scale = np.floor(np.log2(amax)) - emax
161+
q_log2scale = np.clip(q_log2scale, -127.0, 127.0)
162+
return 2.0**q_log2scale
163+
164+
165+
def quantize_block(
166+
fi: BlockFormatInfo,
167+
vals: npt.NDArray[np.float64],
168+
compute_scale: ComputeScaleCallable,
169+
round: RoundMode = RoundMode.TiesToEven,
170+
) -> npt.NDArray[np.float64]:
171+
"""
172+
Encode and decode a block of float :paramref:`vals` using the
173+
block format described by :paramref:`fi`
174+
175+
Args:
176+
fi (BlockFormatInfo): Describes the target block format
177+
vals (numpy.array): Input block
178+
compute_scale ((float, ArrayLike) -> float):
179+
Callable to compute the scale, e.g. :func:`compute_scale_amax`
180+
round (RoundMode): Rounding mode to use, defaults to `TiesToEven`
181+
182+
Returns:
183+
An array of floats representing the quantized values.
184+
185+
Raises:
186+
ValueError: The scale overflows the target scale encoding format.
187+
"""
188+
189+
q_scale = compute_scale(fi.etype.emax, vals)
190+
scaled_vals = vals / q_scale
191+
enc = encode_block(fi, q_scale, scaled_vals, round)
192+
return np.fromiter(decode_block(fi, enc), float)

src/gfloat/formats.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
#: FormatInfo for IEEE-754 Binary32 format
77
format_info_binary32 = FormatInfo(
8-
name="binary32",
8+
name="format_info_binary32",
99
k=32,
1010
precision=24,
1111
emax=127,
@@ -19,7 +19,7 @@
1919

2020
#: FormatInfo for IEEE-754 Binary16 format
2121
format_info_binary16 = FormatInfo(
22-
name="binary16",
22+
name="format_info_binary16",
2323
k=16,
2424
precision=11,
2525
emax=15,
@@ -33,7 +33,7 @@
3333

3434
#: FormatInfo for Google BFloat16 format
3535
format_info_bfloat16 = FormatInfo(
36-
name="bfloat16",
36+
name="format_info_bfloat16",
3737
k=16,
3838
precision=8,
3939
emax=127,
@@ -47,7 +47,7 @@
4747

4848
#: FormatInfo for OCP E5M2 format
4949
format_info_ocp_e5m2 = FormatInfo(
50-
name="ocp_e5m2",
50+
name="format_info_ocp_e5m2",
5151
k=8,
5252
precision=3,
5353
emax=15,
@@ -61,7 +61,7 @@
6161

6262
#: FormatInfo for OCP E4M3 format
6363
format_info_ocp_e4m3 = FormatInfo(
64-
name="ocp_e4m3",
64+
name="format_info_ocp_e4m3",
6565
k=8,
6666
precision=4,
6767
emax=8,
@@ -75,7 +75,7 @@
7575

7676
#: FormatInfo for OCP MX E2M3 format
7777
format_info_ocp_e2m3 = FormatInfo(
78-
name="ocp_e2m3",
78+
name="format_info_ocp_e2m3",
7979
k=6,
8080
precision=4,
8181
emax=2,
@@ -89,7 +89,7 @@
8989

9090
#: FormatInfo for OCP MX E3M2 format
9191
format_info_ocp_e3m2 = FormatInfo(
92-
name="ocp_e3m2",
92+
name="format_info_ocp_e3m2",
9393
k=6,
9494
precision=3,
9595
emax=4,
@@ -103,7 +103,7 @@
103103

104104
#: FormatInfo for OCP MX E2M1 format
105105
format_info_ocp_e2m1 = FormatInfo(
106-
name="ocp_e2m1",
106+
name="format_info_ocp_e2m1",
107107
k=4,
108108
precision=2,
109109
emax=2,
@@ -117,7 +117,7 @@
117117

118118
#: FormatInfo for OCP MX E8M0 format
119119
format_info_ocp_e8m0 = FormatInfo(
120-
name="ocp_e8m0",
120+
name="format_info_ocp_e8m0",
121121
k=8,
122122
precision=1,
123123
emax=127,
@@ -131,7 +131,7 @@
131131

132132
#: FormatInfo for OCP MX INT8 format
133133
format_info_ocp_int8 = FormatInfo(
134-
name="ocp_int8",
134+
name="format_info_ocp_int8",
135135
k=8,
136136
precision=8,
137137
emax=0,
@@ -210,11 +210,11 @@ def format_info_p3109(precision: int) -> FormatInfo:
210210
# Block formats
211211

212212
format_info_mxfp8_e5m2 = BlockFormatInfo(
213-
"ocp_mxfp8_e5m2", format_info_ocp_e5m2, 32, format_info_ocp_e8m0
213+
"format_info_mxfp8_e5m2", format_info_ocp_e5m2, 32, format_info_ocp_e8m0
214214
)
215215

216216
format_info_mxfp8_e4m3 = BlockFormatInfo(
217-
"ocp_mxfp8_e4m3", format_info_ocp_e4m3, 32, format_info_ocp_e8m0
217+
"format_info_mxfp8_e4m3", format_info_ocp_e4m3, 32, format_info_ocp_e8m0
218218
)
219219

220220
format_info_mxfp6_e3m2 = BlockFormatInfo(
@@ -233,11 +233,15 @@ def format_info_p3109(precision: int) -> FormatInfo:
233233
"format_info_mxfp4_e2m1", format_info_ocp_e2m1, 32, format_info_ocp_e8m0
234234
)
235235

236+
format_info_mxint8 = BlockFormatInfo(
237+
"format_info_mxint8", format_info_ocp_int8, 32, format_info_ocp_e8m0
238+
)
239+
236240
all_block_formats = [
237241
format_info_mxfp8_e5m2,
238242
format_info_mxfp8_e4m3,
239243
format_info_mxfp6_e3m2,
240244
format_info_mxfp6_e2m3,
241245
format_info_mxfp4_e2m1,
242-
format_info_mxfp4_e2m1,
246+
format_info_mxint8,
243247
]

src/gfloat/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,5 +400,9 @@ def is_all_subnormal(self) -> bool:
400400
"""
401401
return (self.expBits == 0) and self.has_subnormals
402402

403+
@property
404+
def __name__(self) -> str:
405+
return self.name
406+
403407
def __str__(self) -> str:
404408
return f"{self.name}"

0 commit comments

Comments
 (0)