Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 99c0ab4

Browse files
PhrygianGatesZhicheng Xiong
andauthored
[MLIR][Python] add f8E5M2 and tests for np_to_memref (#106028)
add f8E5M2 and tests for np_to_memref --------- Co-authored-by: Zhicheng Xiong <[email protected]>
1 parent dfe6aff commit 99c0ab4

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

mlir/python/mlir/runtime/np_to_memref.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ class BF16(ctypes.Structure):
3737

3838
_fields_ = [("bf16", ctypes.c_int16)]
3939

40+
class F8E5M2(ctypes.Structure):
41+
"""A ctype representation for MLIR's Float8E5M2."""
42+
43+
_fields_ = [("f8E5M2", ctypes.c_int8)]
44+
4045

4146
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
4247
def as_ctype(dtp):
@@ -49,6 +54,8 @@ def as_ctype(dtp):
4954
return F16
5055
if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
5156
return BF16
57+
if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
58+
return F8E5M2
5259
return np.ctypeslib.as_ctypes_type(dtp)
5360

5461

@@ -65,6 +72,11 @@ def to_numpy(array):
6572
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
6673
if array.dtype == BF16:
6774
return array.view("bfloat16")
75+
assert not (
76+
array.dtype == F8E5M2 and ml_dtypes is None
77+
), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
78+
if array.dtype == F8E5M2:
79+
return array.view("float8_e5m2")
6880
return array
6981

7082

0 commit comments

Comments
 (0)