diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24ad948d..5f5a3066 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace - id: debug-statements - repo: https://github.com/google/pyink - rev: 23.3.1 + rev: 23.10.0 hooks: - id: pyink language_version: python3.9 diff --git a/README.md b/README.md index ba34d55e..9250905c 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ * `float8_e4m3fnuz` * `float8_e5m2` * `float8_e5m2fnuz` + * `float8_p3109_p
` - `int4` and `uint4`: low precision integer types. See below for specifications of these number formats. @@ -107,6 +108,20 @@ This type has the following characteristics: * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000` * denormals when exponent is 0 +### float8_p3109_p
+
+These types represent the types under discussion in IEEE working group P3109,
+"Arithmetic Formats for Machine Learning ", parameterized by precision $p$.
+
+These type has the following characteristics:
+ * Precision $p$: $2 < p < 6$
+ * Exponent bits, E: $8-p$
+ * Exponent bias: 2 ^ (E-1)
+ * Infinities: +Inf, -Inf
+ * No negative zero
+ * Single NaN in the -0 position: `0b10000000` == `0x80`
+ * Denormals when exponent is 0
+
## `int4` and `uint4`
4-bit integer types, where each element is represented unpacked (i.e., padded up
diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py
index 7546ba96..16a6d8fb 100644
--- a/ml_dtypes/__init__.py
+++ b/ml_dtypes/__init__.py
@@ -12,19 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '0.3.1' # Keep in sync with pyproject.toml:version
+__version__ = "0.3.1" # Keep in sync with pyproject.toml:version
__all__ = [
- '__version__',
- 'bfloat16',
- 'finfo',
- 'float8_e4m3b11fnuz',
- 'float8_e4m3fn',
- 'float8_e4m3fnuz',
- 'float8_e5m2',
- 'float8_e5m2fnuz',
- 'iinfo',
- 'int4',
- 'uint4',
+ "__version__",
+ "bfloat16",
+ "finfo",
+ "float8_e4m3b11fnuz",
+ "float8_e4m3fn",
+ "float8_e4m3fnuz",
+ "float8_e5m2",
+ "float8_e5m2fnuz",
+ "float8_p3109_p3",
+ "float8_p3109_p4",
+ "float8_p3109_p5",
+ "iinfo",
+ "int4",
+ "uint4",
]
from typing import Type
@@ -37,6 +40,9 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
+from ml_dtypes._ml_dtypes_ext import float8_p3109_p3
+from ml_dtypes._ml_dtypes_ext import float8_p3109_p4
+from ml_dtypes._ml_dtypes_ext import float8_p3109_p5
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint4
import numpy as np
@@ -47,6 +53,9 @@
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
+float8_p3109_p3: Type[np.generic]
+float8_p3109_p4: Type[np.generic]
+float8_p3109_p5: Type[np.generic]
int4: Type[np.generic]
uint4: Type[np.generic]
diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py
index 451f2766..e4c2fef9 100644
--- a/ml_dtypes/_finfo.py
+++ b/ml_dtypes/_finfo.py
@@ -22,6 +22,10 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
+from ml_dtypes._ml_dtypes_ext import float8_p3109_p3
+from ml_dtypes._ml_dtypes_ext import float8_p3109_p4
+from ml_dtypes._ml_dtypes_ext import float8_p3109_p5
+
import numpy as np
_bfloat16_dtype = np.dtype(bfloat16)
@@ -30,6 +34,9 @@
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz)
+_float8_p3109_p3_dtype = np.dtype(float8_p3109_p3)
+_float8_p3109_p4_dtype = np.dtype(float8_p3109_p4)
+_float8_p3109_p5_dtype = np.dtype(float8_p3109_p5)
class _Bfloat16MachArLike:
@@ -86,6 +93,29 @@ def __init__(self):
self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal)
+class _Float8IEEEMachArLike:
+
+ def __init__(self, p):
+ # These are hard-coded in order to independently test against the computed values in the C++ implementation
+ if p == 3:
+ smallest_normal = float.fromhex("0x1p-15")
+ self.smallest_normal = float8_p3109_p3(smallest_normal)
+ smallest_subnormal = float.fromhex("0x1p-17")
+ self.smallest_subnormal = float8_p3109_p3(smallest_subnormal)
+
+ if p == 4:
+ smallest_normal = float.fromhex("0x1p-7")
+ self.smallest_normal = float8_p3109_p4(smallest_normal)
+ smallest_subnormal = float.fromhex("0x1p-10")
+ self.smallest_subnormal = float8_p3109_p4(smallest_subnormal)
+
+ if p == 5:
+ smallest_normal = float.fromhex("0x1p-3")
+ self.smallest_normal = float8_p3109_p5(smallest_normal)
+ smallest_subnormal = float.fromhex("0x1p-7")
+ self.smallest_subnormal = float8_p3109_p5(smallest_subnormal)
+
+
class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
@@ -360,55 +390,117 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj
+ @staticmethod
+ def _float8_p3109_p_finfo(p):
+ def float_to_str(f):
+ return "%6.2e" % float(f)
+
+ # pylint: disable=protected-access
+ obj = object.__new__(np.finfo)
+
+ if p == 3:
+ dtype = float8_p3109_p3
+ obj.dtype = _float8_p3109_p3_dtype
+ elif p == 4:
+ dtype = float8_p3109_p4
+ obj.dtype = _float8_p3109_p4_dtype
+ elif p == 5:
+ dtype = float8_p3109_p5
+ obj.dtype = _float8_p3109_p5_dtype
+ else:
+ raise NotImplementedError()
+
+ obj._machar = _Float8IEEEMachArLike(p)
+
+ bias = 2 ** (7 - p)
+ tiny = obj._machar.smallest_normal
+ machep = 1 - p
+ eps = 2.0**machep
+ negep = -p
+ epsneg = 2.0**negep
+ max_ = (1 - 2 ** (1 - p)) * 2**bias # 1'0000 - 0'0010 = 0'1110
+
+ if p == 3:
+ assert tiny == float.fromhex("0x1p-15")
+ assert eps == float.fromhex("0x1p-2")
+ assert epsneg == float.fromhex("0x1p-3")
+ assert max_ == float.fromhex("0x1.8p15")
+ elif p == 4:
+ assert tiny == float.fromhex("0x1p-7")
+ assert eps == float.fromhex("0x1p-3")
+ assert epsneg == float.fromhex("0x1p-4")
+ assert max_ == float.fromhex("0x1.Cp7")
+ elif p == 5:
+ assert tiny == float.fromhex("0x1p-3")
+ assert eps == float.fromhex("0x1p-4")
+ assert epsneg == float.fromhex("0x1p-5")
+ assert max_ == float.fromhex("0x1.Ep3")
+ else:
+ raise NotImplementedError()
+
+ obj.bits = 8
+
+ # nextafter(1.0, Inf) - 1.0
+ obj.eps = dtype(eps)
+
+ # The exponent that yields eps.
+ obj.machep = machep
+
+ # 1.0 = nextafter(1.0, -Inf)
+ obj.epsneg = dtype(epsneg)
+
+ # The exponent that yields epsneg.
+ obj.negep = negep
+
+ # The largest representable number.
+ obj.max = dtype(max_)
+
+ # The smallest representable number, typically -max.
+ obj.min = dtype(-max_)
+
+ obj.nexp = 8 - p
+ obj.nmant = p - 1
+ obj.iexp = obj.nexp
+ obj.maxexp = bias
+ obj.minexp = 1 - bias
+
+ # The approximate number of decimal digits to which this kind of float is precise.
+ obj.precision = 1 if p < 4 else 2
+
+ # The approximate decimal resolution of this type, i.e., 10**-precision.
+ obj.resolution = dtype(10**-obj.precision)
+
+ if not hasattr(obj, "tiny"):
+ obj.tiny = dtype(tiny)
+ if not hasattr(obj, "smallest_normal"):
+ obj.smallest_normal = obj._machar.smallest_normal
+ obj.smallest_subnormal = obj._machar.smallest_subnormal
+
+ obj._str_tiny = float_to_str(tiny)
+ obj._str_smallest_normal = float_to_str(tiny)
+ obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
+ obj._str_max = float_to_str(max_)
+ obj._str_epsneg = float_to_str(epsneg)
+ obj._str_eps = float_to_str(eps)
+ obj._str_resolution = float_to_str(obj.resolution)
+ # pylint: enable=protected-access
+ return obj
+
def __new__(cls, dtype):
- if (
- isinstance(dtype, str)
- and dtype == "bfloat16"
- or dtype == _bfloat16_dtype
- ):
- if _bfloat16_dtype not in cls._finfo_cache:
- cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
- return cls._finfo_cache[_bfloat16_dtype]
- if (
- isinstance(dtype, str)
- and dtype == "float8_e4m3b11fnuz"
- or dtype == _float8_e4m3b11fnuz_dtype
+ for ty, constructor in (
+ (_bfloat16_dtype, cls._bfloat16_finfo),
+ (_float8_e4m3b11fnuz_dtype, cls._float8_e4m3b11fnuz_finfo),
+ (_float8_e4m3fn_dtype, cls._float8_e4m3fn_finfo),
+ (_float8_e4m3fnuz_dtype, cls._float8_e4m3fnuz_finfo),
+ (_float8_e5m2_dtype, cls._float8_e5m2_finfo),
+ (_float8_e5m2fnuz_dtype, cls._float8_e5m2fnuz_finfo),
+ (_float8_p3109_p3_dtype, lambda: cls._float8_p3109_p_finfo(3)),
+ (_float8_p3109_p4_dtype, lambda: cls._float8_p3109_p_finfo(4)),
+ (_float8_p3109_p5_dtype, lambda: cls._float8_p3109_p_finfo(5)),
):
- if _float8_e4m3b11fnuz_dtype not in cls._finfo_cache:
- cls._finfo_cache[_float8_e4m3b11fnuz_dtype] = (
- cls._float8_e4m3b11fnuz_finfo()
- )
- return cls._finfo_cache[_float8_e4m3b11fnuz_dtype]
- if (
- isinstance(dtype, str)
- and dtype == "float8_e4m3fn"
- or dtype == _float8_e4m3fn_dtype
- ):
- if _float8_e4m3fn_dtype not in cls._finfo_cache:
- cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo()
- return cls._finfo_cache[_float8_e4m3fn_dtype]
- if (
- isinstance(dtype, str)
- and dtype == "float8_e4m3fnuz"
- or dtype == _float8_e4m3fnuz_dtype
- ):
- if _float8_e4m3fnuz_dtype not in cls._finfo_cache:
- cls._finfo_cache[_float8_e4m3fnuz_dtype] = cls._float8_e4m3fnuz_finfo()
- return cls._finfo_cache[_float8_e4m3fnuz_dtype]
- if (
- isinstance(dtype, str)
- and dtype == "float8_e5m2"
- or dtype == _float8_e5m2_dtype
- ):
- if _float8_e5m2_dtype not in cls._finfo_cache:
- cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo()
- return cls._finfo_cache[_float8_e5m2_dtype]
- if (
- isinstance(dtype, str)
- and dtype == "float8_e5m2fnuz"
- or dtype == _float8_e5m2fnuz_dtype
- ):
- if _float8_e5m2fnuz_dtype not in cls._finfo_cache:
- cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo()
- return cls._finfo_cache[_float8_e5m2fnuz_dtype]
+ if isinstance(dtype, str) and dtype == ty.name or dtype == ty:
+ if ty not in cls._finfo_cache:
+ cls._finfo_cache[ty] = constructor()
+ return cls._finfo_cache[ty]
+
return super().__new__(cls, dtype)
diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc
index 31ac72da..859cd84f 100644
--- a/ml_dtypes/_src/dtypes.cc
+++ b/ml_dtypes/_src/dtypes.cc
@@ -147,6 +147,51 @@ struct TypeDescriptor this_t;
+ using Base = float8_base operator-() const {
+ // TODO: use isnan()
+ if ((this->rep() & 0x7f) == 0x00) {
+ return *this;
+ }
+ return Base::operator-();
+ }
+
+ float8_p3109_p operator-(const float8_p3109_p & other) const {
+ return Base::operator-(other);
+ }
+
+ explicit EIGEN_DEVICE_FUNC operator bool() const { return this->rep() != 0; }
+};
+
+// -----------------------------------------
+
constexpr double ConstexprAbs(double x) { return x < 0.0 ? -x : x; }
constexpr double ConstexprCeil(double x) {
@@ -427,6 +466,7 @@ constexpr int MaxDigits10FromDigits(int digits) {
// C17 5.2.4.2.2p11:
// "minimum negative integer such that 10 raised to that power is in the range
// of normalized floating-point numbers"
+// TODO: https://en.cppreference.com/w/cpp/types/numeric_limits/max_exponent10 says "representable"
// ceil(log10(2**(emin - 1))) == ceil((emin - 1) * log10(2));
constexpr int MinExponent10FromMinExponent(int min_exponent) {
return static_cast min() {
+ return float8_p3109_p ::FromRep(1<<(p-1));
+ }
+ static constexpr float8_p3109_p lowest() {
+ return float8_p3109_p ::FromRep(0xfe);
+ }
+ static constexpr float8_p3109_p max() {
+ return float8_p3109_p ::FromRep(0x7e);
+ }
+ static constexpr float8_p3109_p epsilon() {
+ if constexpr (p < 5) {
+ constexpr int expeps = (-kMantissaBits + kExponentBias) << kMantissaBits;
+ return float8_p3109_p ::FromRep(expeps);
+ }
+ // p >= 5: eps is subnormal
+ return float8_p3109_p ::FromRep(uint8_t(1 << (kExponentBias - 1)));
+ }
+ static constexpr float8_p3109_p round_error() {
+ // Return 0.5
+ return float8_p3109_p ::FromRep((-1 + kExponentBias) << kMantissaBits);
+ }
+ static constexpr float8_p3109_p infinity() {
+ return float8_p3109_p ::FromRep(0x7f);
+ }
+ static constexpr float8_p3109_p quiet_NaN() {
+ return float8_p3109_p ::FromRep(0x80);
+ }
+ static constexpr float8_p3109_p signaling_NaN() {
+ return float8_p3109_p ::FromRep(0x80);
+ }
+ static constexpr float8_p3109_p denorm_min() {
+ return float8_p3109_p ::FromRep(0x01);
+ }
+};
+
} // namespace float8_internal
} // namespace ml_dtypes
@@ -788,6 +888,11 @@ struct numeric_limits {};
+
} // namespace std
namespace ml_dtypes {
@@ -839,6 +944,16 @@ constexpr inline bool (isnan)(const float8_e5m2fnuz& a) {
return a.rep() == 0x80;
}
+template & a) {
+ return a.rep() == 0x80;
+}
+
+template abs(const float8_p3109_p & a) {
+ return isnan(a) ? a : float8_p3109_p ::FromRep(a.rep() & 0x7F);
+}
+
template ;
} // namespace ml_dtypes
@@ -1345,6 +1468,12 @@ EIGEN_DEVICE_FUNC inline bool isinf_impl & x) {
+ return ml_dtypes::float8_internal::isinf(x);
+}
+
+
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl & x) {
+ return ml_dtypes::float8_internal::isnan(x);
+}
+
+
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl & x) {
+ return ml_dtypes::float8_internal::isfinite(x);
+}
+
} // namespace internal
} // namespace Eigen
diff --git a/ml_dtypes/tests/CMakeLists.txt b/ml_dtypes/tests/CMakeLists.txt
new file mode 100644
index 00000000..3dd07e94
--- /dev/null
+++ b/ml_dtypes/tests/CMakeLists.txt
@@ -0,0 +1,42 @@
+cmake_minimum_required(VERSION 3.14)
+project(my_project)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(GOOGLETEST_DOWNLOAD_URL https://github.com/google/googletest/archive/refs/tags/v1.12.0.zip)
+
+include(FetchContent)
+FetchContent_Declare(
+ googletest
+ URL ${GOOGLETEST_DOWNLOAD_URL}
+)
+# For Windows: Prevent overriding the parent project's compiler/linker settings
+set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+FetchContent_MakeAvailable(googletest)
+
+set(ABSL_PROPAGATE_CXX_STD ON)
+set(ABSL_GOOGLETEST_DOWNLOAD_URL ${GOOGLETEST_DOWNLOAD_URL})
+add_subdirectory(abseil-cpp)
+
+enable_testing()
+
+add_executable(
+ float8_test
+ float8_test.cc
+)
+target_include_directories(float8_test PUBLIC
+ ..
+ ../..
+ ../../third_party/eigen
+)
+
+target_link_libraries(
+ float8_test
+ GTest::gtest_main
+ GTest::gmock_main
+ absl::strings
+)
+
+include(GoogleTest)
+gtest_discover_tests(float8_test)
diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py
index d71ae8b1..176134a6 100644
--- a/ml_dtypes/tests/custom_float_test.py
+++ b/ml_dtypes/tests/custom_float_test.py
@@ -35,6 +35,9 @@
float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
float8_e5m2 = ml_dtypes.float8_e5m2
float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz
+float8_p3109_p3 = ml_dtypes.float8_p3109_p3
+float8_p3109_p4 = ml_dtypes.float8_p3109_p4
+float8_p3109_p5 = ml_dtypes.float8_p3109_p5
@contextlib.contextmanager
@@ -105,6 +108,9 @@ def dtype_has_inf(dtype):
float8_e4m3fnuz,
float8_e5m2,
float8_e5m2fnuz,
+ float8_p3109_p3,
+ float8_p3109_p4,
+ float8_p3109_p5,
]
# Values that should round trip exactly to float and back.
@@ -118,7 +124,7 @@ def dtype_has_inf(dtype):
-0.5,
float(ml_dtypes.finfo(dtype).eps),
1.0 + float(ml_dtypes.finfo(dtype).eps),
- 1.0 - float(ml_dtypes.finfo(dtype).eps),
+ 1.0 - float(ml_dtypes.finfo(dtype).eps), # TODO: should be epsneg?
-1.0 - float(ml_dtypes.finfo(dtype).eps),
-1.0 + float(ml_dtypes.finfo(dtype).eps),
3.5,
@@ -159,6 +165,21 @@ def dtype_has_inf(dtype):
range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16)
)
),
+ float8_p3109_p3: list(
+ itertools.chain.from_iterable(
+ range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16)
+ )
+ )[:-1],
+ float8_p3109_p4: list(
+ itertools.chain.from_iterable(
+ range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8)
+ )
+ )[:-1],
+ float8_p3109_p5: list(
+ itertools.chain.from_iterable(
+ range(1 << n, 2 << n, 1 << max(0, n - 4)) for n in range(4)
+ )
+ )[:-1],
}
BITS_TYPE = {
@@ -168,6 +189,9 @@ def dtype_has_inf(dtype):
float8_e4m3fnuz: np.uint8,
float8_e5m2: np.uint8,
float8_e5m2fnuz: np.uint8,
+ float8_p3109_p3: np.uint8,
+ float8_p3109_p4: np.uint8,
+ float8_p3109_p5: np.uint8,
}
@@ -224,19 +248,15 @@ def testRoundTripToNumpy(self, float_type):
np.longdouble,
]:
with self.subTest(dtype.__name__):
- for v in FLOAT_VALUES[float_type]:
- np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v))))
+ vals = FLOAT_VALUES[float_type]
+ for v in vals:
np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v))))
np.testing.assert_equal(
dtype(v), dtype(float_type(np.array(v, dtype)))
)
if dtype != float_type:
- np.testing.assert_equal(
- np.array(FLOAT_VALUES[float_type], dtype),
- float_type(np.array(FLOAT_VALUES[float_type], dtype)).astype(
- dtype
- ),
- )
+ npvals = np.array(vals, dtype)
+ np.testing.assert_equal(npvals, float_type(npvals).astype(dtype))
def testCastBetweenCustomTypes(self, float_type):
for dtype in FLOAT_DTYPES:
@@ -610,9 +630,9 @@ def testArray(self, float_type):
self.assertTrue((x == x).all())
def testComparisons(self, float_type):
- x = np.array([30, 7, -30], dtype=np.float32)
+ x = np.array([15, 7, -15], dtype=np.float32)
bx = x.astype(float_type)
- y = np.array([17, 7, 0], dtype=np.float32)
+ y = np.array([13, 7, 0], dtype=np.float32)
by = y.astype(float_type)
np.testing.assert_equal(x == y, bx == by)
np.testing.assert_equal(x != y, bx != by)
@@ -729,8 +749,8 @@ def testArange(self, float_type):
np.arange(-0.0, -2.0, -0.25, dtype=float_type),
)
np.testing.assert_equal(
- np.arange(-16.0, 16.0, 2.0, dtype=np.float32).astype(float_type),
- np.arange(-16.0, 16.0, 2.0, dtype=float_type),
+ np.arange(-14.0, 14.0, 2.0, dtype=np.float32).astype(float_type),
+ np.arange(-14.0, 14.0, 2.0, dtype=float_type),
)
@ignore_warning(category=RuntimeWarning, message="invalid value encountered")
diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py
index 855c00ba..343b8bda 100644
--- a/ml_dtypes/tests/finfo_test.py
+++ b/ml_dtypes/tests/finfo_test.py
@@ -24,6 +24,9 @@
ml_dtypes.float8_e4m3fnuz,
ml_dtypes.float8_e5m2,
ml_dtypes.float8_e5m2fnuz,
+ ml_dtypes.float8_p3109_p3,
+ ml_dtypes.float8_p3109_p4,
+ ml_dtypes.float8_p3109_p5,
]
DTYPES_WITH_NO_INFINITY = [
diff --git a/ml_dtypes/tests/float8_test.cc b/ml_dtypes/tests/float8_test.cc
index 960f89af..fb3f2cce 100644
--- a/ml_dtypes/tests/float8_test.cc
+++ b/ml_dtypes/tests/float8_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include