29
29
@pytest .mark .parametrize (
30
30
("mx_etype,gf_etype" ),
31
31
[
32
+ (ElemFormat .int8 , format_info_ocp_int8 ),
32
33
(ElemFormat .fp6_e3m2 , format_info_ocp_e3m2 ),
33
34
(ElemFormat .fp4_e2m1 , format_info_ocp_e2m1 ),
34
35
],
@@ -41,7 +42,7 @@ def test_mx(
41
42
gf_etype : FormatInfo ,
42
43
) -> None :
43
44
## Input tensor
44
- A = torch .arange (32 ) / 2 - 5
45
+ A = np .arange (32 ) / 2 - 5
45
46
46
47
## Compute MX quantization
47
48
# Declare block format
@@ -54,19 +55,19 @@ def test_mx(
54
55
)
55
56
56
57
# Compute scale, encode, decode
57
- mx_dq = quantize_mx_op (A , mx_specs , mx_etype , axes = 0 , round = mx_round )
58
+ mx_dq = quantize_mx_op (torch . tensor ( A ) , mx_specs , mx_etype , axes = 0 , round = mx_round )
58
59
59
60
## Compute GFloat quantization
60
61
# Declare block format
61
62
fi = BlockFormatInfo ("test" , gf_etype , 32 , format_info_ocp_e8m0 )
62
63
63
64
# Compute scale - this is not considered GFloat's job, but could easily be added
64
- amax = A . abs (). max ( )
65
- q_log2scale = torch .floor (torch .log2 (amax )). item ( ) - fi .etype .emax
65
+ amax = np . max ( np . abs ( A ) )
66
+ q_log2scale = np .floor (np .log2 (amax )) - fi .etype .emax
66
67
q_scale = 2 ** q_log2scale
67
68
68
69
# Apply scale to encode and decode
69
- enc = encode_block (fi , q_scale , ( a . item () for a in A ) , gf_round )
70
+ enc = encode_block (fi , q_scale , A , gf_round )
70
71
gf_dq = list (decode_block (fi , enc ))
71
72
72
73
## Compare
0 commit comments