Skip to content

Commit e0a8ad4

Browse files
committed
Add mx test
1 parent 9f0fe30 commit e0a8ad4

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

test/test_microxcaling.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
3+
import pytest
4+
5+
import numpy as np
6+
7+
import torch
8+
9+
from mx.mx_ops import quantize_mx_op
10+
from mx.formats import ElemFormat
11+
12+
13+
from gfloat import (
14+
BlockFormatInfo,
15+
encode_block,
16+
decode_block,
17+
encode_float,
18+
decode_float,
19+
round_float,
20+
RoundMode,
21+
)
22+
from gfloat.formats import *
23+
24+
25+
@pytest.mark.parametrize(
26+
("mx_round,gf_round"),
27+
[("even", RoundMode.TiesToEven), ("nearest", RoundMode.TiesToAway)],
28+
)
29+
@pytest.mark.parametrize(
30+
("mx_etype,gf_etype"),
31+
[
32+
(ElemFormat.fp6_e3m2, format_info_ocp_e3m2),
33+
(ElemFormat.fp4_e2m1, format_info_ocp_e2m1),
34+
],
35+
ids=str,
36+
)
37+
def test_mx(mx_round, gf_round, mx_etype, gf_etype):
38+
A = torch.arange(32) / 2 - 5
39+
40+
mx_specs = dict(
41+
block_size=32,
42+
scale_bits=8,
43+
shared_exp_method="max",
44+
mx_flush_fp32_subnorms=False,
45+
custom_cuda=False,
46+
)
47+
48+
mx_dq = quantize_mx_op(A, mx_specs, mx_etype, axes=0, round=mx_round)
49+
50+
fi = BlockFormatInfo("test", gf_etype, 32, format_info_ocp_e8m0)
51+
52+
amax = A.abs().max()
53+
q_log2scale = torch.floor(torch.log2(amax)).item() - fi.etype.emax
54+
q_scale = 2**q_log2scale
55+
56+
print(f"{q_scale=}")
57+
58+
enc = list(encode_block(fi, q_scale, (a.item() for a in A), gf_round))
59+
print(f"{enc=}")
60+
print("decoded_scale=", decode_float(fi.stype, enc[0]).fval)
61+
print("decoded_vals=", list(decode_float(fi.etype, e).fval for e in enc[1:]))
62+
print(
63+
"all_vals=",
64+
*(
65+
str(decode_float(fi.etype, i).fval) + ("" if i & 1 else "e")
66+
for i in range(fi.etype.code_of_max + 1)
67+
),
68+
)
69+
gf_dq = list(decode_block(fi, enc))
70+
print("input=", *(str(v.item()) for v in A))
71+
print("mx_dq=", *(str(v.item()) for v in mx_dq))
72+
print("gf_dq=", *(str(v) for v in gf_dq))
73+
74+
np.testing.assert_allclose(gf_dq, mx_dq)

0 commit comments

Comments
 (0)