Skip to content

Commit f77f9a1

Browse files
committed
Added QuantizationInterface to Float8E5M2Type and Float8E4M3FNType
1 parent 138fc49 commit f77f9a1

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ class Builtin_CachedFloatType<string name, string mnemonic,
101101
// Float8E5M2Type
102102
//===----------------------------------------------------------------------===//
103103

104-
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
104+
def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
105+
["QuantizationInterface"]> {
105106
let summary = "8-bit floating point with 2 bit mantissa";
106107
let description = [{
107108
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
@@ -117,6 +118,21 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
117118

118119
Described in: https://arxiv.org/abs/2209.05433
119120
}];
121+
122+
let extraClassDeclaration = [{
123+
/// QuantizationInterface method implementations
124+
bool isStorageSigned() const { return true; }
125+
/// Get the bit width of this 8-bit floating point type.
126+
unsigned getStorageWidth() const { return 8; }
127+
128+
/// Get default maximum value for this 8-bit floating point type.
129+
int64_t getDefaultMaximum() const { return 57344; }
130+
/// Get default minimum value for this 8-bit floating point type.
131+
int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
132+
133+
/// Get the storage type as a string.
134+
std::string getStorageType() const { return "f8E5M2"; }
135+
}];
120136
}
121137

122138
//===----------------------------------------------------------------------===//
@@ -143,7 +159,8 @@ def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
143159
// Float8E4M3FNType
144160
//===----------------------------------------------------------------------===//
145161

146-
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
162+
def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
163+
["QuantizationInterface"]> {
147164
let summary = "8-bit floating point with 3 bit mantissa";
148165
let description = [{
149166
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
@@ -160,6 +177,21 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> {
160177

161178
Described in: https://arxiv.org/abs/2209.05433
162179
}];
180+
181+
let extraClassDeclaration = [{
182+
/// QuantizationInterface method implementations
183+
bool isStorageSigned() const { return true; }
184+
/// Get the bit width of this 8-bit floating point type.
185+
unsigned getStorageWidth() const { return 8; }
186+
187+
/// Get default maximum value for this 8-bit floating point type.
188+
int64_t getDefaultMaximum() const { return 448; }
189+
/// Get default minimum value for this 8-bit floating point type.
190+
int64_t getDefaultMinimum() const { return -getDefaultMaximum(); }
191+
192+
/// Get the storage type as a string.
193+
std::string getStorageType() const { return "f8E4M3FN"; }
194+
}];
163195
}
164196

165197
//===----------------------------------------------------------------------===//
@@ -561,26 +593,26 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
561593
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
562594

563595
/// QuantizationInterface method implementations
564-
/// Return true if this is a signed integer type.
596+
/// Return true if this is a signed or signless integer type.
565597
bool isStorageSigned() const { return !isUnsigned(); }
566598
/// Get the bit width of this integer type.
567599
unsigned getStorageWidth() const { return getWidth(); }
568600

569-
/// Get default minimum value for this integer type.
570-
int64_t getDefaultMinimum() const {
571-
if (isStorageSigned()) {
572-
return llvm::minIntN(getStorageWidth());
573-
}
574-
return 0;
575-
}
576601
/// Get default maximum value for this integer type.
577602
int64_t getDefaultMaximum() const {
578603
if (isStorageSigned()) {
579604
return llvm::maxIntN(getStorageWidth());
580605
}
581606
return llvm::maxUIntN(getStorageWidth());
582607
}
583-
608+
/// Get default minimum value for this integer type.
609+
int64_t getDefaultMinimum() const {
610+
if (isStorageSigned()) {
611+
return llvm::minIntN(getStorageWidth());
612+
}
613+
return 0;
614+
}
615+
584616
/// Get the storage type as a string.
585617
std::string getStorageType() const {
586618
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());

mlir/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ add_mlir_library(MLIRIR
6464
MLIRCastInterfacesIncGen
6565
MLIRDataLayoutInterfacesIncGen
6666
MLIROpAsmInterfaceIncGen
67+
MLIRQuantizationInterfaceIncGen
6768
MLIRRegionKindInterfaceIncGen
6869
MLIRSideEffectInterfacesIncGen
6970
MLIRSymbolInterfacesIncGen
7071
MLIRTensorEncodingIncGen
71-
MLIRQuantizationInterfaceIncGen
7272

7373
LINK_LIBS PUBLIC
7474
MLIRSupport

0 commit comments

Comments
 (0)