Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit a4d175b

Browse files
authored
[MLIR] Add f8E8M0FNU type (#111028)
This PR adds `f8E8M0FNU` type to MLIR. `f8E8M0FNU` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values. ```c f8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - [PR-107127](llvm/llvm-project#107127) [APFloat] Add APFloat support for E8M0 type - [PR-105573](llvm/llvm-project#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](llvm/llvm-project#107999) [MLIR] Add f6E2M3FN type - [PR-108877](llvm/llvm-project#108877) [MLIR] Add f4E2M1FN type
1 parent ff34b33 commit a4d175b

File tree

5 files changed

+60
-0
lines changed

5 files changed

+60
-0
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
179179
/// context.
180180
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
181181

182+
/// Returns the typeID of an Float8E8M0FNU type.
183+
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void);
184+
185+
/// Checks whether the given type is an f8E8M0FNU type.
186+
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type);
187+
188+
/// Creates an f8E8M0FNU type in the given context. The type is owned by the
189+
/// context.
190+
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx);
191+
182192
/// Returns the typeID of an BFloat16 type.
183193
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
184194

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,27 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
331331
}
332332
};
333333

334+
/// Floating Point Type subclass - Float8E8M0FNUType.
335+
class PyFloat8E8M0FNUType
336+
: public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
337+
public:
338+
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
339+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
340+
mlirFloat8E8M0FNUTypeGetTypeID;
341+
static constexpr const char *pyClassName = "Float8E8M0FNUType";
342+
using PyConcreteType::PyConcreteType;
343+
344+
static void bindDerived(ClassTy &c) {
345+
c.def_static(
346+
"get",
347+
[](DefaultingPyMlirContext context) {
348+
MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
349+
return PyFloat8E8M0FNUType(context->getRef(), t);
350+
},
351+
py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
352+
}
353+
};
354+
334355
/// Floating Point Type subclass - BF16Type.
335356
class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
336357
public:
@@ -953,6 +974,7 @@ void mlir::python::populateIRTypes(py::module &m) {
953974
PyFloat8E4M3B11FNUZType::bind(m);
954975
PyFloat8E5M2FNUZType::bind(m);
955976
PyFloat8E3M4Type::bind(m);
977+
PyFloat8E8M0FNUType::bind(m);
956978
PyBF16Type::bind(m);
957979
PyF16Type::bind(m);
958980
PyTF32Type::bind(m);

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
205205
return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
206206
}
207207

208+
MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
209+
return wrap(Float8E8M0FNUType::getTypeID());
210+
}
211+
212+
bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
213+
return unwrap(type).isFloat8E8M0FNU();
214+
}
215+
216+
MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
217+
return wrap(FloatType::getFloat8E8M0FNU(unwrap(ctx)));
218+
}
219+
208220
MlirTypeID mlirBFloat16TypeGetTypeID() {
209221
return wrap(BFloat16Type::getTypeID());
210222
}

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ __all__ = [
117117
"Float8E4M3Type",
118118
"Float8E5M2FNUZType",
119119
"Float8E5M2Type",
120+
"Float8E8M0FNUType",
120121
"FloatAttr",
121122
"FloatTF32Type",
122123
"FloatType",
@@ -1660,6 +1661,19 @@ class Float8E5M2Type(FloatType):
16601661
@property
16611662
def typeid(self) -> TypeID: ...
16621663

1664+
class Float8E8M0FNUType(FloatType):
1665+
static_typeid: ClassVar[TypeID]
1666+
@staticmethod
1667+
def get(context: Context | None = None) -> Float8E8M0FNUType:
1668+
"""
1669+
Create a float8_e8m0fnu type.
1670+
"""
1671+
@staticmethod
1672+
def isinstance(other: Type) -> bool: ...
1673+
def __init__(self, cast_from_type: Type) -> None: ...
1674+
@property
1675+
def typeid(self) -> TypeID: ...
1676+
16631677
class FloatAttr(Attribute):
16641678
static_typeid: ClassVar[TypeID]
16651679
@staticmethod

mlir/python/mlir/extras/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Float8E4M3FNType,
2121
Float8E4M3Type,
2222
Float8E5M2Type,
23+
Float8E8M0FNUType,
2324
FunctionType,
2425
IndexType,
2526
IntegerType,
@@ -80,6 +81,7 @@ def ui(width):
8081
f4E2M1FN = lambda: Float4E2M1FNType.get()
8182
f6E2M3FN = lambda: Float6E2M3FNType.get()
8283
f6E3M2FN = lambda: Float6E3M2FNType.get()
84+
f8E8M0FNU = lambda: Float8E8M0FNUType.get()
8385

8486
none = lambda: NoneType.get()
8587

0 commit comments

Comments
 (0)