This repository was archived by the owner on Oct 11, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -37,6 +37,11 @@ class BF16(ctypes.Structure):
3737
3838 _fields_ = [("bf16" , ctypes .c_int16 )]
3939
40+ class F8E5M2 (ctypes .Structure ):
41+ """A ctype representation for MLIR's Float8E5M2."""
42+
43+ _fields_ = [("f8E5M2" , ctypes .c_int8 )]
44+
4045
4146# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
4247def as_ctype (dtp ):
@@ -49,6 +54,8 @@ def as_ctype(dtp):
4954 return F16
5055 if ml_dtypes is not None and dtp == ml_dtypes .bfloat16 :
5156 return BF16
57+ if ml_dtypes is not None and dtp == ml_dtypes .float8_e5m2 :
58+ return F8E5M2
5259 return np .ctypeslib .as_ctypes_type (dtp )
5360
5461
@@ -65,6 +72,11 @@ def to_numpy(array):
6572 ), f"bfloat16 requires the ml_dtypes package, please run:\n \n pip install ml_dtypes\n "
6673 if array .dtype == BF16 :
6774 return array .view ("bfloat16" )
75+ assert not (
76+ array .dtype == F8E5M2 and ml_dtypes is None
77+ ), f"float8_e5m2 requires the ml_dtypes package, please run:\n \n pip install ml_dtypes\n "
78+ if array .dtype == F8E5M2 :
79+ return array .view ("float8_e5m2" )
6880 return array
6981
7082
You can’t perform that action at this time.
0 commit comments