Skip to content

Commit 0f8481b

Browse files
committed
Many more tests
1 parent 9828d73 commit 0f8481b

File tree

2 files changed

+50
-39
lines changed

2 files changed

+50
-39
lines changed

src/gfloat/decode_ndarray.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,37 +39,38 @@ def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np=np) -> np.ndarray:
3939
signmask = None
4040
sign = 1.0
4141

42-
exp = (codes >> t) & ((1 << w) - 1)
42+
exp = ((codes >> t) & ((1 << w) - 1)).astype(np.int64)
4343
significand = codes & ((1 << t) - 1)
4444
if fi.is_twos_complement:
4545
significand = np.where(sign < 0, (1 << t) - significand, significand)
4646

4747
expBias = fi.expBias
4848

49-
iszero = (exp == 0) & (significand == 0) if fi.has_zero else False
50-
issubnormal = (exp == 0) & (significand != 0) if fi.has_subnormals else False
49+
iszero = (exp == 0) & (significand == 0) & fi.has_zero
50+
issubnormal = (exp == 0) & (significand != 0) & fi.has_subnormals
5151
isnormal = ~iszero & ~issubnormal
5252
expval = np.where(~isnormal, 1 - expBias, exp - expBias)
5353
fsignificand = np.where(~isnormal, significand * 2**-t, 1.0 + significand * 2**-t)
5454

5555
# Normal/Subnormal/Zero case, other values will be overwritten
5656
fval = np.where(iszero, 0.0, sign * fsignificand * 2.0**expval)
5757

58-
# All-bits-special exponent (ABSE)
59-
if w > 0:
60-
abse = exp == 2**w - 1
61-
min_i_with_nan = 2 ** (p - 1) - fi.num_high_nans
62-
fval = np.where(abse & (significand >= min_i_with_nan), np.nan, fval)
63-
if fi.has_infs:
64-
fval = np.where(
65-
abse & (significand == min_i_with_nan - 1), np.inf * sign, fval
66-
)
58+
if fi.has_infs:
59+
fval = np.where(codes == fi.code_of_posinf, np.inf, fval)
60+
fval = np.where(codes == fi.code_of_neginf, -np.inf, fval)
61+
62+
if fi.num_nans > 0:
63+
code_is_nan = codes == fi.code_of_nan
64+
if w > 0:
65+
# All-bits-special exponent (ABSE)
66+
abse = exp == 2**w - 1
67+
min_code_with_nan = 2 ** (p - 1) - fi.num_high_nans
68+
code_is_nan |= abse & (significand >= min_code_with_nan)
69+
70+
fval = np.where(code_is_nan, np.nan, fval)
6771

6872
# Negative zero
6973
if fi.has_nz:
7074
fval = np.where(iszero & (sign < 0), -0.0, fval)
71-
else:
72-
# Negative zero slot is nan
73-
fval = np.where(codes == fi.code_of_negzero, np.nan, fval)
7475

7576
return fval

test/test_decode.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def dec(code: int) -> float:
2424
if method == "array":
2525

2626
def dec(code: int) -> float:
27-
asnp = np.tile(np.array(code, dtype=np.int64), (2, 3))
27+
asnp = np.tile(np.array(code, dtype=np.uint64), (2, 3))
2828
vals = decode_ndarray(fi, asnp)
2929
val = vals.item(0)
3030
np.testing.assert_equal(val, vals)
@@ -50,10 +50,10 @@ def test_spot_check_ocp_e5m2(method) -> None:
5050
assert fclass(0x00) == FloatClass.ZERO
5151

5252

53-
def test_spot_check_ocp_e4m3() -> None:
53+
@pytest.mark.parametrize("method", methods)
54+
def test_spot_check_ocp_e4m3(method) -> None:
5455
fi = format_info_ocp_e4m3
55-
dec = lambda code: decode_float(fi, code).fval
56-
56+
dec = get_method(method, fi)
5757
assert dec(0x40) == 2.0
5858
assert dec(0x01) == 2.0**-9
5959
assert _isnegzero(dec(0x80))
@@ -62,9 +62,10 @@ def test_spot_check_ocp_e4m3() -> None:
6262
assert np.floor(np.log2(dec(0x7E))) == fi.emax
6363

6464

65-
def test_spot_check_p3109_p3() -> None:
65+
@pytest.mark.parametrize("method", methods)
66+
def test_spot_check_p3109_p3(method: str) -> None:
6667
fi = format_info_p3109(3)
67-
dec = lambda code: decode_float(fi, code).fval
68+
dec = get_method(method, fi)
6869

6970
assert dec(0x01) == 2.0**-17
7071
assert dec(0x40) == 1.0
@@ -73,9 +74,10 @@ def test_spot_check_p3109_p3() -> None:
7374
assert np.floor(np.log2(dec(0x7E))) == fi.emax
7475

7576

76-
def test_spot_check_p3109_p1() -> None:
77+
@pytest.mark.parametrize("method", methods)
78+
def test_spot_check_p3109_p1(method: str) -> None:
7779
fi = format_info_p3109(1)
78-
dec = lambda code: decode_float(fi, code).fval
80+
dec = get_method(method, fi)
7981

8082
assert dec(0x01) == 2.0**-62
8183
assert dec(0x40) == 2.0
@@ -84,9 +86,10 @@ def test_spot_check_p3109_p1() -> None:
8486
assert np.floor(np.log2(dec(0x7E))) == fi.emax
8587

8688

87-
def test_spot_check_binary16() -> None:
89+
@pytest.mark.parametrize("method", methods)
90+
def test_spot_check_binary16(method: str) -> None:
8891
fi = format_info_binary16
89-
dec = lambda code: decode_float(fi, code).fval
92+
dec = get_method(method, fi)
9093

9194
assert dec(0x3C00) == 1.0
9295
assert dec(0x3C01) == 1.0 + 2**-10
@@ -98,9 +101,10 @@ def test_spot_check_binary16() -> None:
98101
assert np.isnan(dec(0x7FFF))
99102

100103

101-
def test_spot_check_bfloat16() -> None:
104+
@pytest.mark.parametrize("method", methods)
105+
def test_spot_check_bfloat16(method: str) -> None:
102106
fi = format_info_bfloat16
103-
dec = lambda code: decode_float(fi, code).fval
107+
dec = get_method(method, fi)
104108

105109
assert dec(0x3F80) == 1
106110
assert dec(0x4000) == 2
@@ -111,10 +115,11 @@ def test_spot_check_bfloat16() -> None:
111115
assert np.isnan(dec(0x7FFF))
112116

113117

114-
def test_spot_check_ocp_e2m3() -> None:
118+
@pytest.mark.parametrize("method", methods)
119+
def test_spot_check_ocp_e2m3(method: str) -> None:
115120
# Test against Table 4 in "OCP Microscaling Formats (MX) v1.0 Spec"
116121
fi = format_info_ocp_e2m3
117-
dec = lambda code: decode_float(fi, code).fval
122+
dec = get_method(method, fi)
118123

119124
assert fi.max == 7.5
120125
assert fi.smallest_subnormal == 0.125
@@ -128,10 +133,11 @@ def test_spot_check_ocp_e2m3() -> None:
128133
assert _isnegzero(dec(0b100000))
129134

130135

131-
def test_spot_check_ocp_e3m2() -> None:
136+
@pytest.mark.parametrize("method", methods)
137+
def test_spot_check_ocp_e3m2(method: str) -> None:
132138
# Test against Table 4 in "OCP Microscaling Formats (MX) v1.0 Spec"
133139
fi = format_info_ocp_e3m2
134-
dec = lambda code: decode_float(fi, code).fval
140+
dec = get_method(method, fi)
135141

136142
assert fi.max == 28.0
137143
assert fi.smallest_subnormal == 0.0625
@@ -145,10 +151,11 @@ def test_spot_check_ocp_e3m2() -> None:
145151
assert _isnegzero(dec(0b100000))
146152

147153

148-
def test_spot_check_ocp_e2m1() -> None:
154+
@pytest.mark.parametrize("method", methods)
155+
def test_spot_check_ocp_e2m1(method: str) -> None:
149156
# Test against Table 5 in "OCP Microscaling Formats (MX) v1.0 Spec"
150157
fi = format_info_ocp_e2m1
151-
dec = lambda code: decode_float(fi, code).fval
158+
dec = get_method(method, fi)
152159

153160
assert fi.max == 6.0
154161
assert fi.smallest_subnormal == 0.5
@@ -168,10 +175,11 @@ def test_spot_check_ocp_e2m1() -> None:
168175
assert _isnegzero(dec(0b1000))
169176

170177

171-
def test_spot_check_ocp_e8m0() -> None:
178+
@pytest.mark.parametrize("method", methods)
179+
def test_spot_check_ocp_e8m0(method: str) -> None:
172180
# Test against Table 7 in "OCP Microscaling Formats (MX) v1.0 Spec"
173181
fi = format_info_ocp_e8m0
174-
dec = lambda code: decode_float(fi, code).fval
182+
dec = get_method(method, fi)
175183
fclass = lambda code: decode_float(fi, code).fclass
176184
assert fi.expBias == 127
177185
assert fi.max == 2.0**127
@@ -187,10 +195,11 @@ def test_spot_check_ocp_e8m0() -> None:
187195
assert fclass(0x00) == FloatClass.NORMAL
188196

189197

190-
def test_spot_check_ocp_int8() -> None:
198+
@pytest.mark.parametrize("method", methods)
199+
def test_spot_check_ocp_int8(method: str) -> None:
191200
# Test against Table TODO in "OCP Microscaling Formats (MX) v1.0 Spec"
192201
fi = format_info_ocp_int8
193-
dec = lambda code: decode_float(fi, code).fval
202+
dec = get_method(method, fi)
194203

195204
assert fi.max == 1.0 + 63.0 / 64
196205
assert fi.smallest == 2.0**-6
@@ -214,8 +223,9 @@ def test_specials(fi: FormatInfo) -> None:
214223

215224

216225
@pytest.mark.parametrize("fi", all_formats)
217-
def test_specials_decode(fi: FormatInfo) -> None:
218-
dec = lambda v: decode_float(fi, v).fval
226+
@pytest.mark.parametrize("method", methods)
227+
def test_specials_decode(method: str, fi: FormatInfo) -> None:
228+
dec = get_method(method, fi)
219229

220230
if fi.has_zero:
221231
assert dec(fi.code_of_zero) == 0

0 commit comments

Comments
 (0)