Skip to content

Commit 138fc49

Browse files
committed
Added new type interafce to let UniformQuantizeType accept other than built in types. Updated parser and printer in Quant dialect
1 parent 22f550b commit 138fc49

File tree

10 files changed

+208
-66
lines changed

10 files changed

+208
-66
lines changed

mlir/cmake/modules/AddMLIR.cmake

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,15 @@ macro(add_mlir_generic_tablegen_target target)
216216
add_dependencies(mlir-generic-headers ${target})
217217
endmacro()
218218

219+
# Declare a dialect in the include directory
220+
function(add_mlir_type_interface interface)
221+
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
222+
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
223+
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
224+
add_public_tablegen_target(MLIR${interface}IncGen)
225+
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
226+
endfunction()
227+
219228
# Generate Documentation
220229
function(add_mlir_doc doc_filename output_file output_directory command)
221230
set(LLVM_TARGET_DEFINITIONS ${doc_filename}.td)

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
167167
// Tablegen Type Declarations
168168
//===----------------------------------------------------------------------===//
169169

170+
#include "mlir/IR/QuantizationInterface.h"
171+
170172
#define GET_TYPEDEF_CLASSES
171173
#include "mlir/IR/BuiltinTypes.h.inc"
172174

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/IR/BuiltinDialect.td"
1919
include "mlir/IR/BuiltinTypeInterfaces.td"
20+
include "mlir/IR/QuantizationInterface.td"
2021
include "mlir/IR/CommonTypeConstraints.td"
2122

2223
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
@@ -501,7 +502,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
501502
//===----------------------------------------------------------------------===//
502503

503504
def Builtin_Integer : Builtin_Type<"Integer", "integer",
504-
[VectorElementTypeInterface]> {
505+
[VectorElementTypeInterface, QuantizationInterface]> {
505506
let summary = "Integer type with arbitrary precision up to a fixed limit";
506507
let description = [{
507508
Syntax:
@@ -558,6 +559,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
558559
/// Integer representation maximal bitwidth.
559560
/// Note: This is aligned with the maximum width of llvm::IntegerType.
560561
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
562+
563+
/// QuantizationInterface method implementations
564+
/// Return true if this is a signed integer type.
565+
bool isStorageSigned() const { return !isUnsigned(); }
566+
/// Get the bit width of this integer type.
567+
unsigned getStorageWidth() const { return getWidth(); }
568+
569+
/// Get default minimum value for this integer type.
570+
int64_t getDefaultMinimum() const {
571+
if (isStorageSigned()) {
572+
return llvm::minIntN(getStorageWidth());
573+
}
574+
return 0;
575+
}
576+
/// Get default maximum value for this integer type.
577+
int64_t getDefaultMaximum() const {
578+
if (isStorageSigned()) {
579+
return llvm::maxIntN(getStorageWidth());
580+
}
581+
return llvm::maxUIntN(getStorageWidth());
582+
}
583+
584+
/// Get the storage type as a string.
585+
std::string getStorageType() const {
586+
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
587+
}
561588
}];
562589
}
563590

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
add_mlir_interface(SymbolInterfaces)
22
add_mlir_interface(RegionKindInterface)
33

4+
add_mlir_type_interface(QuantizationInterface)
5+
46
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
57
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
68
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_IR_QuantizationInterface_H
11+
#define MLIR_IR_QuantizationInterface_H
12+
13+
#include "mlir/IR/Types.h"
14+
15+
// Forward declarations for the types we need in the implementation
16+
namespace mlir {
17+
class IntegerType;
18+
} // namespace mlir
19+
20+
#include "mlir/IR/QuantizationInterface.h.inc"
21+
22+
#endif // MLIR_IR_QuantizationInterface_H
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#ifndef MLIR_IR_QUANTIZATIONINTERFACE
2+
#define MLIR_IR_QUANTIZATIONINTERFACE
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
7+
let description = [{
8+
Interface for types that can be used as storage types in Quant dialect.
9+
This interface provides methods to determine storage characteristics for quantization purposes.
10+
}];
11+
let cppNamespace = "::mlir";
12+
13+
let methods = [
14+
InterfaceMethod<[{
15+
Check if the storage type is signed.
16+
Returns true if the type represents signed values, false for unsigned.
17+
}],
18+
"bool", "isStorageSigned", (ins)>,
19+
20+
InterfaceMethod<[{
21+
Get the bit width of this integer type.
22+
Returns the number of bits used to store values of this type.
23+
}],
24+
"unsigned", "getStorageWidth", (ins)>,
25+
26+
InterfaceMethod<[{
27+
Get default minimum value for this integer type.
28+
}],
29+
"int64_t", "getDefaultMinimum", (ins)>,
30+
31+
InterfaceMethod<[{
32+
Get default maximum value for this integer type.
33+
}],
34+
"int64_t", "getDefaultMaximum", (ins)>,
35+
36+
InterfaceMethod<[{
37+
Get the storage type as a string.
38+
}],
39+
"std::string", "getStorageType", (ins)>
40+
];
41+
42+
}
43+
44+
#endif // MLIR_IR_QUANTIZATIONINTERFACE

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1010
#include "TypeDetail.h"
1111
#include "mlir/Dialect/Quant/IR/Quant.h"
12+
#include "mlir/IR/QuantizationInterface.h"
1213

1314
#include "mlir/IR/BuiltinTypes.h"
1415
#include "mlir/IR/MLIRContext.h"
@@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
5253
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
5354
if (!intStorageType)
5455
return emitError() << "storage type must be integral";
55-
unsigned integralWidth = intStorageType.getWidth();
56-
57-
// Verify storage width.
58-
if (integralWidth == 0 || integralWidth > MaxStorageBits)
59-
return emitError() << "illegal storage type size: " << integralWidth;
60-
61-
// Verify storageTypeMin and storageTypeMax.
62-
bool isSigned =
63-
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
64-
int64_t defaultIntegerMin =
65-
getDefaultMinimumForInteger(isSigned, integralWidth);
66-
int64_t defaultIntegerMax =
67-
getDefaultMaximumForInteger(isSigned, integralWidth);
68-
if (storageTypeMax - storageTypeMin <= 0 ||
69-
storageTypeMin < defaultIntegerMin ||
70-
storageTypeMax > defaultIntegerMax) {
71-
return emitError() << "illegal storage min and storage max: ("
72-
<< storageTypeMin << ":" << storageTypeMax << ")";
56+
57+
if (auto quantizationInterface =
58+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
59+
unsigned integralWidth = quantizationInterface.getStorageWidth();
60+
61+
// Verify storage width.
62+
if (integralWidth == 0 || integralWidth > MaxStorageBits)
63+
return emitError() << "illegal storage type size: " << integralWidth;
64+
65+
int64_t defaultMin = quantizationInterface.getDefaultMinimum();
66+
int64_t defaultMax = quantizationInterface.getDefaultMaximum();
67+
68+
if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
69+
storageTypeMax > defaultMax) {
70+
return emitError() << "illegal storage min and storage max: ("
71+
<< storageTypeMin << ":" << storageTypeMax << ")";
72+
}
73+
74+
return success();
7375
}
74-
return success();
76+
77+
return emitError() << "storage type must implement QuantizationInterface";
7578
}
7679

7780
Type QuantizedType::getStorageType() const {
@@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
8790
}
8891

8992
bool QuantizedType::hasStorageTypeBounds() const {
90-
unsigned int integralWidth = getStorageTypeIntegralWidth();
91-
bool isSignedInteger = isSigned();
92-
int64_t defaultIntegerMin =
93-
getDefaultMinimumForInteger(isSignedInteger, integralWidth);
94-
int64_t defaultIntegerMax =
95-
getDefaultMaximumForInteger(isSignedInteger, integralWidth);
96-
return defaultIntegerMin != getStorageTypeMin() ||
97-
defaultIntegerMax != getStorageTypeMax();
93+
Type storageType = static_cast<ImplType *>(impl)->storageType;
94+
auto quantizationInterface =
95+
llvm::dyn_cast<QuantizationInterface>(storageType);
96+
97+
int64_t defaultMin = quantizationInterface.getDefaultMinimum();
98+
int64_t defaultMax = quantizationInterface.getDefaultMaximum();
99+
100+
return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
98101
}
99102

100103
unsigned QuantizedType::getStorageTypeIntegralWidth() const {
101-
// NOTE: If ever supporting non-integral storage types, some other scheme
102-
// for determining the width will be needed.
103-
return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
104+
Type storageType = static_cast<ImplType *>(impl)->storageType;
105+
auto quantizationInterface =
106+
llvm::dyn_cast<QuantizationInterface>(storageType);
107+
108+
return quantizationInterface.getStorageWidth();
104109
}
105110

106111
Type QuantizedType::getExpressedType() const {

mlir/lib/Dialect/Quant/IR/TypeParser.cpp

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1111
#include "mlir/IR/BuiltinTypes.h"
1212
#include "mlir/IR/DialectImplementation.h"
13+
#include "mlir/IR/QuantizationInterface.h"
1314
#include "mlir/IR/Types.h"
1415
#include "llvm/ADT/APFloat.h"
1516

1617
using namespace mlir;
1718
using namespace quant;
1819

19-
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
20+
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
2021
auto typeLoc = parser.getCurrentLocation();
21-
IntegerType type;
22+
Type type;
2223

2324
// Parse storage type (alpha_ident, integer_literal).
2425
StringRef identifier;
@@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
2728
if (result.has_value()) {
2829
if (!succeeded(*result))
2930
return nullptr;
30-
isSigned = !type.isUnsigned();
31-
storageTypeWidth = type.getWidth();
32-
} else if (succeeded(parser.parseKeyword(&identifier))) {
33-
// Otherwise, this must be an unsigned integer (`u` integer-literal).
34-
if (!identifier.consume_front("u")) {
35-
parser.emitError(typeLoc, "illegal storage type prefix");
31+
32+
if (auto quantizationInterface =
33+
llvm::dyn_cast<QuantizationInterface>(type)) {
34+
isSigned = quantizationInterface.isStorageSigned();
35+
storageTypeWidth = quantizationInterface.getStorageWidth();
36+
} else {
37+
parser.emitError(typeLoc, "illegal quantized storage type alias");
3638
return nullptr;
3739
}
38-
if (identifier.getAsInteger(10, storageTypeWidth)) {
39-
parser.emitError(typeLoc, "expected storage type width");
40+
} else if (succeeded(parser.parseKeyword(&identifier))) {
41+
// Otherwise, this must be an unsigned integer (`u` integer-literal)
42+
if (identifier.consume_front("u")) {
43+
if (identifier.getAsInteger(10, storageTypeWidth)) {
44+
parser.emitError(typeLoc, "expected storage type width");
45+
return nullptr;
46+
}
47+
isSigned = false;
48+
type = parser.getBuilder().getIntegerType(storageTypeWidth);
49+
} else {
50+
parser.emitError(typeLoc, "illegal quantized storage type alias");
4051
return nullptr;
4152
}
42-
isSigned = false;
43-
type = parser.getBuilder().getIntegerType(storageTypeWidth);
4453
} else {
4554
return nullptr;
4655
}
@@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
5564
return type;
5665
}
5766

58-
static ParseResult parseStorageRange(DialectAsmParser &parser,
59-
IntegerType storageType, bool isSigned,
67+
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
6068
int64_t &storageTypeMin,
6169
int64_t &storageTypeMax) {
62-
int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
63-
isSigned, storageType.getWidth());
64-
int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
65-
isSigned, storageType.getWidth());
70+
int64_t defaultMin, defaultMax;
71+
if (auto quantizationInterface =
72+
llvm::dyn_cast<QuantizationInterface>(storageType)) {
73+
defaultMin = quantizationInterface.getDefaultMinimum();
74+
defaultMax = quantizationInterface.getDefaultMaximum();
75+
}
76+
6677
if (failed(parser.parseOptionalLess())) {
67-
storageTypeMin = defaultIntegerMin;
68-
storageTypeMax = defaultIntegerMax;
78+
storageTypeMin = defaultMin;
79+
storageTypeMax = defaultMax;
6980
return success();
7081
}
7182

@@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
7586
parser.getCurrentLocation(&maxLoc) ||
7687
parser.parseInteger(storageTypeMax) || parser.parseGreater())
7788
return failure();
78-
if (storageTypeMin < defaultIntegerMin) {
89+
if (storageTypeMin < defaultMin) {
7990
return parser.emitError(minLoc, "illegal storage type minimum: ")
8091
<< storageTypeMin;
8192
}
82-
if (storageTypeMax > defaultIntegerMax) {
93+
if (storageTypeMax > defaultMax) {
8394
return parser.emitError(maxLoc, "illegal storage type maximum: ")
8495
<< storageTypeMax;
8596
}
@@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
113124
/// storage-type ::= (`i` | `u`) integer-literal
114125
/// expressed-type-spec ::= `:` `f` integer-literal
115126
static Type parseAnyType(DialectAsmParser &parser) {
116-
IntegerType storageType;
127+
Type storageType;
117128
FloatType expressedType;
118129
unsigned typeFlags = 0;
119130
int64_t storageTypeMin;
@@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
134145
}
135146

136147
// Storage type range.
137-
if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
138-
storageTypeMax)) {
148+
if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
139149
return nullptr;
140150
}
141151

@@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
322332
/// scale-zero-tensor (`,` scale-zero-tensor)*
323333
/// `}`
324334
static Type parseUniformType(DialectAsmParser &parser) {
325-
IntegerType storageType;
335+
Type storageType;
326336
FloatType expressedType;
327337
unsigned typeFlags = 0;
328338
int64_t storageTypeMin;
@@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
350360
}
351361

352362
// Storage type range.
353-
if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
354-
storageTypeMax)) {
363+
if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
355364
return nullptr;
356365
}
357366

@@ -487,12 +496,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {
487496

488497
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
489498
// storage type
490-
unsigned storageWidth = type.getStorageTypeIntegralWidth();
491-
bool isSigned = type.isSigned();
492-
if (isSigned) {
493-
out << "i" << storageWidth;
494-
} else {
495-
out << "u" << storageWidth;
499+
if (auto quantizationInterface =
500+
llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
501+
out << quantizationInterface.getStorageType();
496502
}
497503

498504
// storageTypeMin and storageTypeMax if not default.

0 commit comments

Comments
 (0)