Skip to content

Commit b7c801a

Browse files
committed
comments
1 parent 6712804 commit b7c801a

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

test/test_microxcaling.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ def test_mx(
4040
mx_etype: ElemFormat,
4141
gf_etype: FormatInfo,
4242
) -> None:
43-
# Input tensor
43+
## Input tensor
4444
A = torch.arange(32) / 2 - 5
4545

46-
# Compute MX quantization
46+
## Compute MX quantization
47+
# Declare block format
4748
mx_specs = dict(
4849
block_size=32,
4950
scale_bits=8,
@@ -52,12 +53,14 @@ def test_mx(
5253
custom_cuda=False,
5354
)
5455

56+
# Compute scale, encode, decode
5557
mx_dq = quantize_mx_op(A, mx_specs, mx_etype, axes=0, round=mx_round)
5658

57-
# Compute GFloat quantization
59+
## Compute GFloat quantization
60+
# Declare block format
5861
fi = BlockFormatInfo("test", gf_etype, 32, format_info_ocp_e8m0)
5962

60-
# Compute scale - this is not considered GFloat's job
63+
# Compute scale - this is not considered GFloat's job, but could easily be added
6164
amax = A.abs().max()
6265
q_log2scale = torch.floor(torch.log2(amax)).item() - fi.etype.emax
6366
q_scale = 2**q_log2scale
@@ -66,5 +69,5 @@ def test_mx(
6669
enc = encode_block(fi, q_scale, (a.item() for a in A), gf_round)
6770
gf_dq = list(decode_block(fi, enc))
6871

69-
# Compare
72+
## Compare
7073
np.testing.assert_allclose(gf_dq, mx_dq)

0 commit comments

Comments (0)