Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 0ae3839

Browse files
authored
[MLIR] Add f8E4M3 IEEE 754 type (#97118)
This PR adds `f8E4M3` type to mlir. `f8E4M3` type follows IEEE 754 convention ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` Related PRs: - [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type
1 parent be03b0e commit 0ae3839

File tree

5 files changed

+60
-1
lines changed

5 files changed

+60
-1
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
8989
/// context.
9090
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
9191

92+
/// Returns the typeID of an Float8E4M3 type.
93+
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void);
94+
95+
/// Checks whether the given type is an f8E4M3 type.
96+
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type);
97+
98+
/// Creates an f8E4M3 type in the given context. The type is owned by the
99+
/// context.
100+
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx);
101+
92102
/// Returns the typeID of an Float8E4M3FN type.
93103
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
94104

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class PyFloat8E4M3FNType
143143
}
144144
};
145145

146-
/// Floating Point Type subclass - Float8M5E2Type.
146+
/// Floating Point Type subclass - Float8E5M2Type.
147147
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
148148
public:
149149
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
@@ -163,6 +163,26 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
163163
}
164164
};
165165

166+
/// Floating Point Type subclass - Float8E4M3Type.
167+
class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
168+
public:
169+
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
170+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
171+
mlirFloat8E4M3TypeGetTypeID;
172+
static constexpr const char *pyClassName = "Float8E4M3Type";
173+
using PyConcreteType::PyConcreteType;
174+
175+
static void bindDerived(ClassTy &c) {
176+
c.def_static(
177+
"get",
178+
[](DefaultingPyMlirContext context) {
179+
MlirType t = mlirFloat8E4M3TypeGet(context->get());
180+
return PyFloat8E4M3Type(context->getRef(), t);
181+
},
182+
py::arg("context") = py::none(), "Create a float8_e4m3 type.");
183+
}
184+
};
185+
166186
/// Floating Point Type subclass - Float8E4M3FNUZ.
167187
class PyFloat8E4M3FNUZType
168188
: public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
@@ -840,6 +860,7 @@ void mlir::python::populateIRTypes(py::module &m) {
840860
PyIndexType::bind(m);
841861
PyFloat8E4M3FNType::bind(m);
842862
PyFloat8E5M2Type::bind(m);
863+
PyFloat8E4M3Type::bind(m);
843864
PyFloat8E4M3FNUZType::bind(m);
844865
PyFloat8E4M3B11FNUZType::bind(m);
845866
PyFloat8E5M2FNUZType::bind(m);

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,18 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
9797
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
9898
}
9999

100+
MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
101+
return wrap(Float8E4M3Type::getTypeID());
102+
}
103+
104+
bool mlirTypeIsAFloat8E4M3(MlirType type) {
105+
return unwrap(type).isFloat8E4M3();
106+
}
107+
108+
MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
109+
return wrap(FloatType::getFloat8E4M3(unwrap(ctx)));
110+
}
111+
100112
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
101113
return wrap(Float8E4M3FNType::getTypeID());
102114
}

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ __all__ = [
123123
"Float8E4M3B11FNUZType",
124124
"Float8E4M3FNType",
125125
"Float8E4M3FNUZType",
126+
"Float8E4M3Type",
126127
"Float8E5M2FNUZType",
127128
"Float8E5M2Type",
128129
"FloatAttr",
@@ -1575,6 +1576,19 @@ class Float8E4M3FNUZType(FloatType):
15751576
@property
15761577
def typeid(self) -> TypeID: ...
15771578

1579+
class Float8E4M3Type(FloatType):
1580+
static_typeid: ClassVar[TypeID]
1581+
@staticmethod
1582+
def get(context: Optional[Context] = None) -> Float8E4M3Type:
1583+
"""
1584+
Create a float8_e4m3 type.
1585+
"""
1586+
@staticmethod
1587+
def isinstance(other: Type) -> bool: ...
1588+
def __init__(self, cast_from_type: Type) -> None: ...
1589+
@property
1590+
def typeid(self) -> TypeID: ...
1591+
15781592
class Float8E5M2FNUZType(FloatType):
15791593
static_typeid: ClassVar[TypeID]
15801594
@staticmethod

mlir/python/mlir/extras/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
F64Type,
1515
Float8E4M3B11FNUZType,
1616
Float8E4M3FNType,
17+
Float8E4M3Type,
1718
Float8E5M2Type,
1819
FunctionType,
1920
IndexType,
@@ -68,6 +69,7 @@ def ui(width):
6869
bf16 = lambda: BF16Type.get()
6970

7071
f8E5M2 = lambda: Float8E5M2Type.get()
72+
f8E4M3 = lambda: Float8E4M3Type.get()
7173
f8E4M3FN = lambda: Float8E4M3FNType.get()
7274
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
7375

0 commit comments

Comments
 (0)