Skip to content

Commit 1833c0c

Browse files
committed
Add float8_e4m3
1 parent b157c19 commit 1833c0c

File tree

9 files changed

+363
-9
lines changed

9 files changed

+363
-9
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Added new 8-bit float type following IEEE 754 convention:
27+
`ml_dtypes.float8_e4m3`.
28+
2629
## [0.4.0] - 2024-04-01
2730

2831
* Updates `ml_dtypes` for compatibility with future NumPy 2.0 release.

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
1111
- `float8_*`: several experimental 8-bit floating point representations
1212
including:
13+
* `float8_e4m3`
1314
* `float8_e4m3b11fnuz`
1415
* `float8_e4m3fn`
1516
* `float8_e4m3fnuz`
@@ -64,6 +65,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.
6465

6566
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
6667

68+
### `float8_e4m3`
69+
70+
Exponent: 4, Mantissa: 3, exponent bias: 7. Follows IEEE 754 conventions, with NaN and inf.
71+
6772
### `float8_e4m3b11fnuz`
6873

6974
Exponent: 4, Mantissa: 3, bias: 11.

ml_dtypes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"__version__",
1818
"bfloat16",
1919
"finfo",
20+
"float8_e4m3",
2021
"float8_e4m3b11fnuz",
2122
"float8_e4m3fn",
2223
"float8_e4m3fnuz",
@@ -34,6 +35,7 @@
3435
from ml_dtypes._finfo import finfo
3536
from ml_dtypes._iinfo import iinfo
3637
from ml_dtypes._ml_dtypes_ext import bfloat16
38+
from ml_dtypes._ml_dtypes_ext import float8_e4m3
3739
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
3840
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
3941
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
@@ -46,6 +48,7 @@
4648
import numpy as np
4749

4850
bfloat16: Type[np.generic]
51+
float8_e4m3: Type[np.generic]
4952
float8_e4m3b11fnuz: Type[np.generic]
5053
float8_e4m3fn: Type[np.generic]
5154
float8_e4m3fnuz: Type[np.generic]

ml_dtypes/_finfo.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Dict
1818

1919
from ml_dtypes._ml_dtypes_ext import bfloat16
20+
from ml_dtypes._ml_dtypes_ext import float8_e4m3
2021
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
2122
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
2223
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
@@ -25,6 +26,7 @@
2526
import numpy as np
2627

2728
_bfloat16_dtype = np.dtype(bfloat16)
29+
_float8_e4m3_dtype = np.dtype(float8_e4m3)
2830
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
2931
_float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)
3032
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
@@ -41,6 +43,15 @@ def __init__(self):
4143
self.smallest_subnormal = bfloat16(smallest_subnormal)
4244

4345

46+
class _Float8E4m3MachArLike:
47+
48+
def __init__(self):
49+
smallest_normal = float.fromhex("0x1p-6")
50+
self.smallest_normal = float8_e4m3(smallest_normal)
51+
smallest_subnormal = float.fromhex("0x1p-9")
52+
self.smallest_subnormal = float8_e4m3(smallest_subnormal)
53+
54+
4455
class _Float8E4m3b11fnuzMachArLike:
4556

4657
def __init__(self):
@@ -135,6 +146,51 @@ def float_to_str(f):
135146
# pylint: enable=protected-access
136147
return obj
137148

149+
@staticmethod
150+
def _float8_e4m3_finfo():
151+
def float_to_str(f):
152+
return "%6.2e" % float(f)
153+
154+
tiny = float.fromhex("0x1p-6") # 1/64 min normal
155+
resolution = 0.1
156+
eps = float.fromhex("0x1p-3") # 1/8
157+
epsneg = float.fromhex("0x1p-4") # 1/16
158+
max_ = float.fromhex("0x1.Ep7") # 240 max normal
159+
160+
obj = object.__new__(np.finfo)
161+
obj.dtype = _float8_e4m3_dtype
162+
obj.bits = 8
163+
obj.eps = float8_e4m3(eps)
164+
obj.epsneg = float8_e4m3(epsneg)
165+
obj.machep = -3
166+
obj.negep = -4
167+
obj.max = float8_e4m3(max_)
168+
obj.min = float8_e4m3(-max_)
169+
obj.nexp = 4
170+
obj.nmant = 3
171+
obj.iexp = obj.nexp
172+
obj.maxexp = 8
173+
obj.minexp = -6
174+
obj.precision = 1
175+
obj.resolution = float8_e4m3(resolution)
176+
# pylint: disable=protected-access
177+
obj._machar = _Float8E4m3MachArLike()
178+
if not hasattr(obj, "tiny"):
179+
obj.tiny = float8_e4m3(tiny)
180+
if not hasattr(obj, "smallest_normal"):
181+
obj.smallest_normal = obj._machar.smallest_normal
182+
obj.smallest_subnormal = obj._machar.smallest_subnormal
183+
184+
obj._str_tiny = float_to_str(tiny)
185+
obj._str_smallest_normal = float_to_str(tiny)
186+
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
187+
obj._str_max = float_to_str(max_)
188+
obj._str_epsneg = float_to_str(epsneg)
189+
obj._str_eps = float_to_str(eps)
190+
obj._str_resolution = float_to_str(resolution)
191+
# pylint: enable=protected-access
192+
return obj
193+
138194
@staticmethod
139195
def _float8_e4m3b11fnuz_finfo():
140196
def float_to_str(f):
@@ -369,6 +425,14 @@ def __new__(cls, dtype):
369425
if _bfloat16_dtype not in cls._finfo_cache:
370426
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
371427
return cls._finfo_cache[_bfloat16_dtype]
428+
if (
429+
isinstance(dtype, str)
430+
and dtype == "float8_e4m3"
431+
or dtype == _float8_e4m3_dtype
432+
):
433+
if _float8_e4m3_dtype not in cls._finfo_cache:
434+
cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo()
435+
return cls._finfo_cache[_float8_e4m3_dtype]
372436
if (
373437
isinstance(dtype, str)
374438
and dtype == "float8_e4m3b11fnuz"

ml_dtypes/_src/dtypes.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
6060
static constexpr char kNpyDescrByteorder = '=';
6161
};
6262

63+
template <>
64+
struct TypeDescriptor<float8_e4m3> : CustomFloatType<float8_e4m3> {
65+
typedef float8_e4m3 T;
66+
static constexpr bool is_floating = true;
67+
static constexpr bool is_integral = false;
68+
static constexpr const char* kTypeName = "float8_e4m3";
69+
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3";
70+
static constexpr const char* kTpDoc = "float8_e4m3 floating-point values";
71+
// Set e4m3 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
72+
static constexpr char kNpyDescrKind = 'V'; // Void
73+
static constexpr char kNpyDescrType = '7'; // '4' is reserved for e4m3fn
74+
static constexpr char kNpyDescrByteorder = '='; // Native byte order
75+
};
76+
6377
template <>
6478
struct TypeDescriptor<float8_e4m3b11fnuz>
6579
: CustomFloatType<float8_e4m3b11fnuz> {
@@ -269,6 +283,9 @@ bool Initialize() {
269283
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
270284
return false;
271285
}
286+
if (!RegisterFloatDtype<float8_e4m3>(numpy.get())) {
287+
return false;
288+
}
272289
if (!RegisterFloatDtype<float8_e4m3b11fnuz>(numpy.get())) {
273290
return false;
274291
}
@@ -319,6 +336,12 @@ bool Initialize() {
319336
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e4m3fn, float>();
320337
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2, float>();
321338
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e5m2, float>();
339+
success &= RegisterTwoWayCustomCast<float8_e4m3, bfloat16, float>();
340+
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3b11fnuz, float>();
341+
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2fnuz, float>();
342+
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fnuz, float>();
343+
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fn, float>();
344+
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2, float>();
322345
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
323346
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
324347
return success;
@@ -349,6 +372,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
349372
return nullptr;
350373
}
351374

375+
if (PyObject_SetAttrString(m.get(), "float8_e4m3",
376+
reinterpret_cast<PyObject*>(
377+
TypeDescriptor<float8_e4m3>::type_ptr)) < 0) {
378+
return nullptr;
379+
}
352380
if (PyObject_SetAttrString(
353381
m.get(), "float8_e4m3b11fnuz",
354382
reinterpret_cast<PyObject*>(

0 commit comments

Comments
 (0)