@@ -40,8 +40,10 @@ def test_mx(
40
40
mx_etype : ElemFormat ,
41
41
gf_etype : FormatInfo ,
42
42
) -> None :
43
+ # Input tensor
43
44
A = torch .arange (32 ) / 2 - 5
44
45
46
+ # Compute MX quantization
45
47
mx_specs = dict (
46
48
block_size = 32 ,
47
49
scale_bits = 8 ,
@@ -52,28 +54,17 @@ def test_mx(
52
54
53
55
mx_dq = quantize_mx_op (A , mx_specs , mx_etype , axes = 0 , round = mx_round )
54
56
57
+ # Compute GFloat quantization
55
58
fi = BlockFormatInfo ("test" , gf_etype , 32 , format_info_ocp_e8m0 )
56
59
60
+ # Compute scale - this is not considered GFloat's job
57
61
amax = A .abs ().max ()
58
62
q_log2scale = torch .floor (torch .log2 (amax )).item () - fi .etype .emax
59
63
q_scale = 2 ** q_log2scale
60
64
61
- print (f"{ q_scale = } " )
62
-
63
- enc = list (encode_block (fi , q_scale , (a .item () for a in A ), gf_round ))
64
- print (f"{ enc = } " )
65
- print ("decoded_scale=" , decode_float (fi .stype , enc [0 ]).fval )
66
- print ("decoded_vals=" , list (decode_float (fi .etype , e ).fval for e in enc [1 :]))
67
- print (
68
- "all_vals=" ,
69
- * (
70
- str (decode_float (fi .etype , i ).fval ) + ("" if i & 1 else "e" )
71
- for i in range (fi .etype .code_of_max + 1 )
72
- ),
73
- )
65
+ # Apply scale to encode and decode
66
+ enc = encode_block (fi , q_scale , (a .item () for a in A ), gf_round )
74
67
gf_dq = list (decode_block (fi , enc ))
75
- print ("input=" , * (str (v .item ()) for v in A ))
76
- print ("mx_dq=" , * (str (v .item ()) for v in mx_dq ))
77
- print ("gf_dq=" , * (str (v ) for v in gf_dq ))
78
68
69
+ # Compare
79
70
np .testing .assert_allclose (gf_dq , mx_dq )
0 commit comments