diff --git a/docs/source/03-value-tables.ipynb b/docs/source/03-value-tables.ipynb index 731430d..24c2586 100644 --- a/docs/source/03-value-tables.ipynb +++ b/docs/source/03-value-tables.ipynb @@ -386,7 +386,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -399,24 +399,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -425,7 +425,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -550,7 +550,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -563,24 +563,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -589,7 +589,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -710,7 +710,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -723,24 +723,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -749,7 +749,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -870,7 +870,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -883,24 +883,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -909,7 +909,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -1039,7 +1039,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -1052,24 +1052,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -1078,7 +1078,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -1391,7 +1391,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -1404,24 +1404,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 24.5%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -1430,7 +1430,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -1931,7 +1931,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -1944,24 +1944,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 24.5%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -1970,7 +1970,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -2480,7 +2480,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -2493,24 +2493,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 24.5%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -2519,7 +2519,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -3031,7 +3031,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -3044,24 +3044,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -3070,7 +3070,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -3161,7 +3161,7 @@ " border-collapse: collapse;\n", " }\n", "\n", - " \n", + "\n", " table {\n", " margin: 0pt;\n", " font-family: monospace;\n", @@ -3174,24 +3174,24 @@ " height: 4ex;\n", " vertical-align: top;\n", " }\n", - " \n", + "\n", " td {\n", " text-align: left;\n", " border: solid 2px #ccc;\n", " width: 49.0%;\n", " }\n", - " \n", + "\n", " .special {\n", " color: #874723;\n", " }\n", - " \n", + "\n", " .subnormal {\n", " color: #0121a7;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", - " \n", + "\n", " @media (prefers-color-scheme: dark) {\n", " .special {\n", " color: orange;\n", @@ -3200,7 +3200,7 @@ " .subnormal {\n", " color: cyan;\n", " }\n", - " \n", + "\n", " .normal {\n", " }\n", " }\n", @@ -3264,7 +3264,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "gfloat", "language": "python", "name": "python3" }, @@ -3278,7 +3278,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.9" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/src/gfloat/round_ndarray.py b/src/gfloat/round_ndarray.py index 76cc7cb..43d3776 100644 --- a/src/gfloat/round_ndarray.py +++ b/src/gfloat/round_ndarray.py @@ -28,6 +28,20 @@ def _ldexp(v: npt.NDArray, s: npt.NDArray) -> npt.NDArray: return xp.where(v < 1.0, vlo, vhi) +def _frexp(v: npt.NDArray) -> npt.NDArray: + xp = array_api_compat.array_namespace(v) + if ( + array_api_compat.is_torch_array(v) # type: ignore + or array_api_compat.is_jax_array(v) # type: ignore + or array_api_compat.is_numpy_array(v) + ): + return xp.frexp(v) + + # Beware #49 + expval = xp.astype(xp.floor(xp.log2(v)), xp.int64) + return (xp.nan, expval) + + def round_ndarray( fi: FormatInfo, v: npt.NDArray, @@ -88,7 +102,7 @@ def to_int(x: npt.NDArray) -> npt.NDArray: def to_float(x: npt.NDArray) -> npt.NDArray: return xp.astype(x, v.dtype) - expval = to_int(xp.floor(xp.log2(absv_masked))) + expval = _frexp(absv_masked)[1] - 1 if fi.has_subnormals: expval = xp_maximum(expval, 1 - bias) diff --git a/test/test_round.py b/test/test_round.py index badea4d..9a91824 100644 --- a/test/test_round.py +++ b/test/test_round.py @@ -595,3 +595,21 @@ def test_stochastic_rounding_scalar_eq_array( # Ensure faithful rounding if alpha < 1.0: assert ((val_array == v0) | (val_array == v1)).all() + + +def test_large_bfloat(): + # from https://github.com/graphcore-research/gfloat/pull/49 + + a = 6.6461399789245764e35 + b = 6.620178494631905e35 + + assert b < a + rounded_a = round_float(format_info_bfloat16, a, RoundMode.TowardZero) + rounded_b = round_float(format_info_bfloat16, b, RoundMode.TowardZero) + + assert rounded_b <= rounded_a + + rounded_a = round_ndarray(format_info_bfloat16, np.array([a]), RoundMode.TowardZero) + rounded_b = round_ndarray(format_info_bfloat16, np.array([b]), RoundMode.TowardZero) + + assert all(rounded_b <= rounded_a)