@@ -24,7 +24,7 @@ def dec(code: int) -> float:
24
24
if method == "array" :
25
25
26
26
def dec (code : int ) -> float :
27
- asnp = np .tile (np .array (code , dtype = np .int64 ), (2 , 3 ))
27
+ asnp = np .tile (np .array (code , dtype = np .uint64 ), (2 , 3 ))
28
28
vals = decode_ndarray (fi , asnp )
29
29
val = vals .item (0 )
30
30
np .testing .assert_equal (val , vals )
@@ -50,10 +50,10 @@ def test_spot_check_ocp_e5m2(method) -> None:
50
50
assert fclass (0x00 ) == FloatClass .ZERO
51
51
52
52
53
- def test_spot_check_ocp_e4m3 () -> None :
53
+ @pytest .mark .parametrize ("method" , methods )
54
+ def test_spot_check_ocp_e4m3 (method ) -> None :
54
55
fi = format_info_ocp_e4m3
55
- dec = lambda code : decode_float (fi , code ).fval
56
-
56
+ dec = get_method (method , fi )
57
57
assert dec (0x40 ) == 2.0
58
58
assert dec (0x01 ) == 2.0 ** - 9
59
59
assert _isnegzero (dec (0x80 ))
@@ -62,9 +62,10 @@ def test_spot_check_ocp_e4m3() -> None:
62
62
assert np .floor (np .log2 (dec (0x7E ))) == fi .emax
63
63
64
64
65
- def test_spot_check_p3109_p3 () -> None :
65
+ @pytest .mark .parametrize ("method" , methods )
66
+ def test_spot_check_p3109_p3 (method : str ) -> None :
66
67
fi = format_info_p3109 (3 )
67
- dec = lambda code : decode_float ( fi , code ). fval
68
+ dec = get_method ( method , fi )
68
69
69
70
assert dec (0x01 ) == 2.0 ** - 17
70
71
assert dec (0x40 ) == 1.0
@@ -73,9 +74,10 @@ def test_spot_check_p3109_p3() -> None:
73
74
assert np .floor (np .log2 (dec (0x7E ))) == fi .emax
74
75
75
76
76
- def test_spot_check_p3109_p1 () -> None :
77
+ @pytest .mark .parametrize ("method" , methods )
78
+ def test_spot_check_p3109_p1 (method : str ) -> None :
77
79
fi = format_info_p3109 (1 )
78
- dec = lambda code : decode_float ( fi , code ). fval
80
+ dec = get_method ( method , fi )
79
81
80
82
assert dec (0x01 ) == 2.0 ** - 62
81
83
assert dec (0x40 ) == 2.0
@@ -84,9 +86,10 @@ def test_spot_check_p3109_p1() -> None:
84
86
assert np .floor (np .log2 (dec (0x7E ))) == fi .emax
85
87
86
88
87
- def test_spot_check_binary16 () -> None :
89
+ @pytest .mark .parametrize ("method" , methods )
90
+ def test_spot_check_binary16 (method : str ) -> None :
88
91
fi = format_info_binary16
89
- dec = lambda code : decode_float ( fi , code ). fval
92
+ dec = get_method ( method , fi )
90
93
91
94
assert dec (0x3C00 ) == 1.0
92
95
assert dec (0x3C01 ) == 1.0 + 2 ** - 10
@@ -98,9 +101,10 @@ def test_spot_check_binary16() -> None:
98
101
assert np .isnan (dec (0x7FFF ))
99
102
100
103
101
- def test_spot_check_bfloat16 () -> None :
104
+ @pytest .mark .parametrize ("method" , methods )
105
+ def test_spot_check_bfloat16 (method : str ) -> None :
102
106
fi = format_info_bfloat16
103
- dec = lambda code : decode_float ( fi , code ). fval
107
+ dec = get_method ( method , fi )
104
108
105
109
assert dec (0x3F80 ) == 1
106
110
assert dec (0x4000 ) == 2
@@ -111,10 +115,11 @@ def test_spot_check_bfloat16() -> None:
111
115
assert np .isnan (dec (0x7FFF ))
112
116
113
117
114
- def test_spot_check_ocp_e2m3 () -> None :
118
+ @pytest .mark .parametrize ("method" , methods )
119
+ def test_spot_check_ocp_e2m3 (method : str ) -> None :
115
120
# Test against Table 4 in "OCP Microscaling Formats (MX) v1.0 Spec"
116
121
fi = format_info_ocp_e2m3
117
- dec = lambda code : decode_float ( fi , code ). fval
122
+ dec = get_method ( method , fi )
118
123
119
124
assert fi .max == 7.5
120
125
assert fi .smallest_subnormal == 0.125
@@ -128,10 +133,11 @@ def test_spot_check_ocp_e2m3() -> None:
128
133
assert _isnegzero (dec (0b100000 ))
129
134
130
135
131
- def test_spot_check_ocp_e3m2 () -> None :
136
+ @pytest .mark .parametrize ("method" , methods )
137
+ def test_spot_check_ocp_e3m2 (method : str ) -> None :
132
138
# Test against Table 4 in "OCP Microscaling Formats (MX) v1.0 Spec"
133
139
fi = format_info_ocp_e3m2
134
- dec = lambda code : decode_float ( fi , code ). fval
140
+ dec = get_method ( method , fi )
135
141
136
142
assert fi .max == 28.0
137
143
assert fi .smallest_subnormal == 0.0625
@@ -145,10 +151,11 @@ def test_spot_check_ocp_e3m2() -> None:
145
151
assert _isnegzero (dec (0b100000 ))
146
152
147
153
148
- def test_spot_check_ocp_e2m1 () -> None :
154
+ @pytest .mark .parametrize ("method" , methods )
155
+ def test_spot_check_ocp_e2m1 (method : str ) -> None :
149
156
# Test against Table 5 in "OCP Microscaling Formats (MX) v1.0 Spec"
150
157
fi = format_info_ocp_e2m1
151
- dec = lambda code : decode_float ( fi , code ). fval
158
+ dec = get_method ( method , fi )
152
159
153
160
assert fi .max == 6.0
154
161
assert fi .smallest_subnormal == 0.5
@@ -168,10 +175,11 @@ def test_spot_check_ocp_e2m1() -> None:
168
175
assert _isnegzero (dec (0b1000 ))
169
176
170
177
171
- def test_spot_check_ocp_e8m0 () -> None :
178
+ @pytest .mark .parametrize ("method" , methods )
179
+ def test_spot_check_ocp_e8m0 (method : str ) -> None :
172
180
# Test against Table 7 in "OCP Microscaling Formats (MX) v1.0 Spec"
173
181
fi = format_info_ocp_e8m0
174
- dec = lambda code : decode_float ( fi , code ). fval
182
+ dec = get_method ( method , fi )
175
183
fclass = lambda code : decode_float (fi , code ).fclass
176
184
assert fi .expBias == 127
177
185
assert fi .max == 2.0 ** 127
@@ -187,10 +195,11 @@ def test_spot_check_ocp_e8m0() -> None:
187
195
assert fclass (0x00 ) == FloatClass .NORMAL
188
196
189
197
190
- def test_spot_check_ocp_int8 () -> None :
198
+ @pytest .mark .parametrize ("method" , methods )
199
+ def test_spot_check_ocp_int8 (method : str ) -> None :
191
200
# Test against Table TODO in "OCP Microscaling Formats (MX) v1.0 Spec"
192
201
fi = format_info_ocp_int8
193
- dec = lambda code : decode_float ( fi , code ). fval
202
+ dec = get_method ( method , fi )
194
203
195
204
assert fi .max == 1.0 + 63.0 / 64
196
205
assert fi .smallest == 2.0 ** - 6
@@ -214,8 +223,9 @@ def test_specials(fi: FormatInfo) -> None:
214
223
215
224
216
225
@pytest .mark .parametrize ("fi" , all_formats )
217
- def test_specials_decode (fi : FormatInfo ) -> None :
218
- dec = lambda v : decode_float (fi , v ).fval
226
+ @pytest .mark .parametrize ("method" , methods )
227
+ def test_specials_decode (method : str , fi : FormatInfo ) -> None :
228
+ dec = get_method (method , fi )
219
229
220
230
if fi .has_zero :
221
231
assert dec (fi .code_of_zero ) == 0
0 commit comments