
Commit 9a48d99

Add int8 test, tidyfy
1 parent b63611a commit 9a48d99

File tree: 1 file changed (+6 −5)

test/test_microxcaling.py

Lines changed: 6 additions & 5 deletions
@@ -29,6 +29,7 @@
 @pytest.mark.parametrize(
     ("mx_etype,gf_etype"),
     [
+        (ElemFormat.int8, format_info_ocp_int8),
         (ElemFormat.fp6_e3m2, format_info_ocp_e3m2),
         (ElemFormat.fp4_e2m1, format_info_ocp_e2m1),
     ],
@@ -41,7 +42,7 @@ def test_mx(
     gf_etype: FormatInfo,
 ) -> None:
     ## Input tensor
-    A = torch.arange(32) / 2 - 5
+    A = np.arange(32) / 2 - 5
 
     ## Compute MX quantization
     # Declare block format
@@ -54,19 +55,19 @@ def test_mx(
     )
 
     # Compute scale, encode, decode
-    mx_dq = quantize_mx_op(A, mx_specs, mx_etype, axes=0, round=mx_round)
+    mx_dq = quantize_mx_op(torch.tensor(A), mx_specs, mx_etype, axes=0, round=mx_round)
 
     ## Compute GFloat quantization
     # Declare block format
     fi = BlockFormatInfo("test", gf_etype, 32, format_info_ocp_e8m0)
 
     # Compute scale - this is not considered GFloat's job, but could easily be added
-    amax = A.abs().max()
-    q_log2scale = torch.floor(torch.log2(amax)).item() - fi.etype.emax
+    amax = np.max(np.abs(A))
+    q_log2scale = np.floor(np.log2(amax)) - fi.etype.emax
     q_scale = 2**q_log2scale
 
     # Apply scale to encode and decode
-    enc = encode_block(fi, q_scale, (a.item() for a in A), gf_round)
+    enc = encode_block(fi, q_scale, A, gf_round)
     gf_dq = list(decode_block(fi, enc))
 
     ## Compare
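
For context, the GFloat side of the test derives a power-of-two block scale from the block's largest magnitude and the element type's emax. Below is a minimal standalone sketch of that arithmetic (not part of the commit), using plain NumPy only; the EMAX value of 2 is an assumption matching an e2m1 element type rather than something taken from the diff.

    import numpy as np

    A = np.arange(32) / 2 - 5                       # same 32-element input block as the test
    EMAX = 2                                        # assumed emax, e.g. for an e2m1 element type

    amax = np.max(np.abs(A))                        # largest magnitude in the block -> 10.5
    q_log2scale = np.floor(np.log2(amax)) - EMAX    # floor(log2(10.5)) - 2 = 1.0
    q_scale = 2.0 ** q_log2scale                    # power-of-two block scale -> 2.0

With this scale applied, the block's maximum becomes 10.5 / 2 = 5.25, which fits within the e2m1 representable range (maximum normal value 6.0); that is the purpose of subtracting emax from the exponent of amax.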
