Skip to content

Commit ed8c91e

Browse files
authored
Add f8E4M3 and f8E3M4 types support (#2482)
This PR adds f8E4M3 and f8E3M4 types support. f8E4M3 and f8E3M4 types follow IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal 
number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](#2486) [RFC] Add f8E4M3 and f8E3M4 types support - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-16585](openxla/xla#16585) Add support for float8_e4m3
1 parent dd97e7a commit ed8c91e

19 files changed

+3124
-75
lines changed

docs/spec.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ BooleanType ::= 'i1'
245245
IntegerType ::= SignedIntegerType | UnsignedIntegerType
246246
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
247247
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
248-
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
249-
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
248+
FloatType ::= 'f8E3M4' | 'f8E4M3' | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ'
249+
| 'f8E5M2' | 'f8E5M2FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
250250
TensorFloat32 ::= 'tf32'
251251
ComplexType ::= 'complex' '<' ComplexElementType '>'
252252
ComplexElementType ::= 'f32' | 'f64'
@@ -265,6 +265,8 @@ values of type `tensor<T>`).
265265
inclusive, and unsigned `uiN` types represent integer values from `0` to
266266
`2^N-1` inclusive.
267267
* **Floating-point types** can be one of the following:
268+
* `f8E3M4`, `f8E4M3` and `f8E5M2` 8-bit floating point numbers following
269+
IEEE-754 conventions.
268270
* `f8E4M3FN` and `f8E5M2` types corresponding to respectively the
269271
`E4M3` and `E5M2` encodings of the FP8 format described in
270272
[FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).

stablehlo/dialect/Base.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,9 @@ FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
647647
if (lhsComponentCount != 1 || rhsComponentCount != 1) return failure();
648648

649649
auto isAnyF8 = [](Type t) {
650-
return llvm::isa<Float8E4M3FNType, Float8E5M2Type, Float8E4M3FNUZType,
651-
Float8E4M3B11FNUZType, Float8E5M2FNUZType>(t);
650+
return llvm::isa<Float8E3M4Type, Float8E4M3Type, Float8E4M3FNType,
651+
Float8E5M2Type, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
652+
Float8E5M2FNUZType>(t);
652653
};
653654
if (isAnyF8(lhsPrecisionType) && isAnyF8(rhsPrecisionType) &&
654655
accumulationType.isF32() && numPrimitiveOperations == 1) {

stablehlo/dialect/Base.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>;
4242
def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>;
4343
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
4444

45-
def HLO_Float : AnyTypeOf<[F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2,
46-
F8E5M2FNUZ, F16, F32, F64, BF16]>;
45+
def HLO_Float : AnyTypeOf<[F8E3M4, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ,
46+
F8E5M2, F8E5M2FNUZ, F16, F32, F64, BF16]>;
4747
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;
4848

4949
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;

stablehlo/dialect/Version.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Version {
3838
static FailureOr<Version> fromString(llvm::StringRef versionRef);
3939

4040
/// Return a Version representing the current VHLO dialect version.
41-
static Version getCurrentVersion() { return Version(1, 6, 4); }
41+
static Version getCurrentVersion() { return Version(1, 7, 0); }
4242

4343
/// Return a Version representing the minimum supported VHLO dialect version.
4444
static Version getMinimumVersion() { return Version(0, 9, 0); }

stablehlo/dialect/VhloBytecode.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ enum AttributeCode {
189189
/// location is updated.
190190
enum TypeCode {
191191
// TO ADD TYPE: Add an enum value with doc string for new type.
192-
// Next available code: 35
192+
// Next available code: 37
193193

194194
/// BooleanV1Type {
195195
/// }
@@ -216,6 +216,14 @@ enum TypeCode {
216216
/// }
217217
kFloatF64V1Type = 5,
218218

219+
/// FloatF8E3M4V1Type {
220+
/// }
221+
kFloatF8E3M4V1Type = 36,
222+
223+
/// FloatF8E4M3V1Type {
224+
/// }
225+
kFloatF8E4M3V1Type = 35,
226+
219227
/// FloatF8E4M3FNV1Type {
220228
/// }
221229
kFloatF8E4M3FNV1Type = 6,
@@ -698,9 +706,11 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
698706
if (isa<FloatF16V1Type>(type)) return APFloat::IEEEhalf();
699707
if (isa<FloatF32V1Type>(type)) return APFloat::IEEEsingle();
700708
if (isa<FloatF64V1Type>(type)) return APFloat::IEEEdouble();
709+
if (isa<FloatF8E3M4V1Type>(type)) return APFloat::Float8E3M4();
701710
if (isa<FloatF8E4M3FNUZV1Type>(type)) return APFloat::Float8E4M3FNUZ();
702711
if (isa<FloatF8E4M3B11FNUZV1Type>(type)) return APFloat::Float8E4M3B11FNUZ();
703712
if (isa<FloatF8E4M3FNV1Type>(type)) return APFloat::Float8E4M3FN();
713+
if (isa<FloatF8E4M3V1Type>(type)) return APFloat::Float8E4M3();
704714
if (isa<FloatF8E5M2FNUZV1Type>(type)) return APFloat::Float8E5M2FNUZ();
705715
if (isa<FloatF8E5M2V1Type>(type)) return APFloat::Float8E5M2();
706716
if (isa<FloatTF32V1Type>(type)) return APFloat::FloatTF32();
@@ -968,6 +978,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
968978
return FloatF64V1Type::get(getContext());
969979
case vhlo_encoding::kFloatF8E5M2V1Type:
970980
return FloatF8E5M2V1Type::get(getContext());
981+
case vhlo_encoding::kFloatF8E4M3V1Type:
982+
return FloatF8E4M3V1Type::get(getContext());
971983
case vhlo_encoding::kFloatF8E4M3FNV1Type:
972984
return FloatF8E4M3FNV1Type::get(getContext());
973985
case vhlo_encoding::kFloatF8E5M2FNUZV1Type:
@@ -976,6 +988,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
976988
return FloatF8E4M3FNUZV1Type::get(getContext());
977989
case vhlo_encoding::kFloatF8E4M3B11FNUZV1Type:
978990
return FloatF8E4M3B11FNUZV1Type::get(getContext());
991+
case vhlo_encoding::kFloatF8E3M4V1Type:
992+
return FloatF8E3M4V1Type::get(getContext());
979993
case vhlo_encoding::kFloatTF32V1Type:
980994
return FloatTF32V1Type::get(getContext());
981995
case vhlo_encoding::kFunctionV1Type:
@@ -1060,6 +1074,14 @@ LogicalResult VhloBytecodeInterface::writeType(
10601074
LOG_WRITE_CALL;
10611075
return writer.writeVarInt(vhlo_encoding::kFloatF64V1Type), success();
10621076
})
1077+
.Case([&](FloatF8E3M4V1Type) {
1078+
LOG_WRITE_CALL;
1079+
return writer.writeVarInt(vhlo_encoding::kFloatF8E3M4V1Type), success();
1080+
})
1081+
.Case([&](FloatF8E4M3V1Type) {
1082+
LOG_WRITE_CALL;
1083+
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3V1Type), success();
1084+
})
10631085
.Case([&](FloatF8E4M3FNV1Type) {
10641086
LOG_WRITE_CALL;
10651087
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNV1Type),

stablehlo/dialect/VhloDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def VHLO_Dialect : Dialect {
4545
1.4.0: Add `tan` op to StableHLO opset.
4646
1.5.0: Make collective ops (`all_reduce`, `all_gather`, `all_to_all`) variadic.
4747
1.6.0: Add DotAlgorithm specification to `dot_general`.
48+
1.7.0: Introduce `f8E4M3` and `f8E3M4` types.
4849
}];
4950

5051
let useDefaultAttributePrinterParser = 0;

stablehlo/dialect/VhloTypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
8484
[&](Float32Type type) { return FloatF32V1Type::get(type.getContext()); });
8585
addConversion(
8686
[&](Float64Type type) { return FloatF64V1Type::get(type.getContext()); });
87+
addConversion([&](Float8E3M4Type type) {
88+
return FloatF8E3M4V1Type::get(type.getContext());
89+
});
90+
addConversion([&](Float8E4M3Type type) {
91+
return FloatF8E4M3V1Type::get(type.getContext());
92+
});
8793
addConversion([&](Float8E4M3FNType type) {
8894
return FloatF8E4M3FNV1Type::get(type.getContext());
8995
});
@@ -176,6 +182,12 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
176182
[&](FloatF32V1Type type) { return Float32Type::get(type.getContext()); });
177183
addConversion(
178184
[&](FloatF64V1Type type) { return Float64Type::get(type.getContext()); });
185+
addConversion([&](FloatF8E3M4V1Type type) {
186+
return Float8E3M4Type::get(type.getContext());
187+
});
188+
addConversion([&](FloatF8E4M3V1Type type) {
189+
return Float8E4M3Type::get(type.getContext());
190+
});
179191
addConversion([&](FloatF8E4M3FNV1Type type) {
180192
return Float8E4M3FNType::get(type.getContext());
181193
});

stablehlo/dialect/VhloTypes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def VHLO_FloatF32V1 : VHLO_TypeDef<"FloatF32V1", "f32_v1", "0.9.0", "current">;
7979
// Corresponds to the 'f64' FloatType from the StableHLO spec.
8080
def VHLO_FloatF64V1 : VHLO_TypeDef<"FloatF64V1","f64_v1", "0.9.0", "current">;
8181

82+
// Corresponds to the 'f8E3M4' FloatType from the StableHLO spec.
83+
def VHLO_FloatF8E3M4V1 : VHLO_TypeDef<"FloatF8E3M4V1", "f8E3M4_v1", "1.7.0", "current">;
84+
85+
// Corresponds to the 'f8E4M3' FloatType from the StableHLO spec.
86+
def VHLO_FloatF8E4M3V1 : VHLO_TypeDef<"FloatF8E4M3V1", "f8E4M3_v1", "1.7.0", "current">;
87+
8288
// Corresponds to the 'f8E4M3FN' FloatType from the StableHLO spec.
8389
def VHLO_FloatF8E4M3FNV1 : VHLO_TypeDef<"FloatF8E4M3FNV1", "f8E4M3FN_v1", "0.9.0", "current">;
8490

stablehlo/reference/Tensor.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,21 @@ Element Tensor::get(const Index &index) const {
118118
getSizeInBytes(elementType) * flattenIndex(getShape(), index);
119119

120120
// Handle floating-point types.
121+
if (elementType.isFloat8E3M4()) {
122+
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
123+
return Element(elementType, APFloat(llvm::APFloatBase::Float8E3M4(),
124+
APInt(8, *elementData)));
125+
}
121126
if (elementType.isFloat8E4M3B11FNUZ()) {
122127
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
123128
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(),
124129
APInt(8, *elementData)));
125130
}
131+
if (elementType.isFloat8E4M3()) {
132+
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
133+
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3(),
134+
APInt(8, *elementData)));
135+
}
126136
if (elementType.isFloat8E4M3FN()) {
127137
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
128138
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FN(),
@@ -252,7 +262,8 @@ void Tensor::set(const Index &index, const Element &element) {
252262
getSizeInBytes(elementType) * flattenIndex(getShape(), index);
253263

254264
// Handle floating-point types.
255-
if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() ||
265+
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
266+
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
256267
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
257268
elementType.isFloat8E5M2FNUZ()) {
258269
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
@@ -446,17 +457,18 @@ Tensor makeTensor(DenseElementsAttr attr) {
446457
auto elementType = type.getElementType();
447458

448459
// Handle floating-point types.
449-
if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() ||
460+
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
461+
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
450462
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
451463
elementType.isFloat8E5M2FNUZ()) {
452464
auto floatValues = llvm::map_to_vector(
453465
attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
454466
return value.bitcastToAPInt().getZExtValue();
455467
});
456468

457-
// For f8E4M3B11FNUZ, f8E4M3FN, f8E4M3FNUZ, f8E5M2, and f8E5M2FNUZ
458-
// floating-point types, we use uint8_t as their storage type because there
459-
// are no builtin types for those.
469+
// For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and
470+
// f8E5M2FNUZ floating-point types, we use uint8_t as their storage type
471+
// because there are no builtin types for those.
460472
return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
461473
floatValues));
462474
}

stablehlo/reference/Types.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ bool isSupportedIntegerType(Type type) {
4848
}
4949

5050
bool isSupportedFloatType(Type type) {
51-
return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
51+
return type.isFloat8E3M4() || type.isFloat8E4M3B11FNUZ() ||
52+
type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
5253
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
5354
type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() ||
5455
type.isF32() || type.isF64();

0 commit comments

Comments
 (0)