Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/cmake/modules/AddMLIR.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ function(add_mlir_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

# Declare a dialect in the include directory
function(add_mlir_type_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIR${interface}IncGen)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

# Generate Documentation
function(add_mlir_doc doc_filename output_file output_directory command)
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class BaseMemRefType : public Type,
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//

#include "mlir/IR/QuantizationInterface.h"

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"

Expand Down
29 changes: 28 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/QuantizationInterface.td"
include "mlir/IR/CommonTypeConstraints.td"

// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
Expand Down Expand Up @@ -497,7 +498,7 @@ def Builtin_Index : Builtin_Type<"Index", "index",
//===----------------------------------------------------------------------===//

def Builtin_Integer : Builtin_Type<"Integer", "integer",
[VectorElementTypeInterface]> {
[VectorElementTypeInterface, QuantizationInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
Expand Down Expand Up @@ -554,6 +555,32 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;

/// QuantizationInterface method implementations
/// Return true if this is a signed integer type.
bool isStorageSigned() const { return !isUnsigned(); }
/// Get the bit width of this integer type.
unsigned getStorageWidth() const { return getWidth(); }

/// Get default minimum value for this integer type.
int64_t getDefaultMinimum() const {
if (isStorageSigned()) {
return llvm::minIntN(getStorageWidth());
}
return 0;
}
/// Get default maximum value for this integer type.
int64_t getDefaultMaximum() const {
if (isStorageSigned()) {
return llvm::maxIntN(getStorageWidth());
}
return llvm::maxUIntN(getStorageWidth());
}

/// Get the storage type as a string.
std::string getStorageType() const {
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
}
}];
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)

add_mlir_type_interface(QuantizationInterface)

set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(OpAsmAttrInterface.cpp.inc -gen-attr-interface-defs)
Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/IR/QuantizationInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- QuantizationInterface.h - Quantzation Interfaces --------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_QuantizationInterface_H
#define MLIR_IR_QuantizationInterface_H

#include "mlir/IR/Types.h"

// Forward declarations for the types we need in the implementation
namespace mlir {
class IntegerType;
} // namespace mlir

#include "mlir/IR/QuantizationInterface.h.inc"

#endif // MLIR_IR_QuantizationInterface_H
44 changes: 44 additions & 0 deletions mlir/include/mlir/IR/QuantizationInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef MLIR_IR_QUANTIZATIONINTERFACE
#define MLIR_IR_QUANTIZATIONINTERFACE

include "mlir/IR/OpBase.td"

def QuantizationInterface : TypeInterface<"QuantizationInterface"> {
let description = [{
Interface for types that can be used as storage types in Quant dialect.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, this interface would live in the quant dialect. It would then be attached to the IntegerType as an external model. Unfortunately, that makes it impossible to declare the interface as "promised" due to layering constraints.

This interface provides methods to determine storage characteristics for quantization purposes.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Check if the storage type is signed.
Returns true if the type represents signed values, false for unsigned.
}],
"bool", "isStorageSigned", (ins)>,

InterfaceMethod<[{
Get the bit width of this integer type.
Returns the number of bits used to store values of this type.
}],
"unsigned", "getStorageWidth", (ins)>,

InterfaceMethod<[{
Get default minimum value for this integer type.
}],
"int64_t", "getDefaultMinimum", (ins)>,

InterfaceMethod<[{
Get default maximum value for this integer type.
}],
"int64_t", "getDefaultMaximum", (ins)>,

InterfaceMethod<[{
Get the storage type as a string.
}],
"std::string", "getStorageType", (ins)>
];

}

#endif // MLIR_IR_QUANTIZATIONINTERFACE
65 changes: 35 additions & 30 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/IR/QuantizationInterface.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -52,26 +53,28 @@ QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
if (!intStorageType)
return emitError() << "storage type must be integral";
unsigned integralWidth = intStorageType.getWidth();

// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;

// Verify storageTypeMin and storageTypeMax.
bool isSigned =
(flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
int64_t defaultIntegerMin =
getDefaultMinimumForInteger(isSigned, integralWidth);
int64_t defaultIntegerMax =
getDefaultMaximumForInteger(isSigned, integralWidth);
if (storageTypeMax - storageTypeMin <= 0 ||
storageTypeMin < defaultIntegerMin ||
storageTypeMax > defaultIntegerMax) {
return emitError() << "illegal storage min and storage max: ("
<< storageTypeMin << ":" << storageTypeMax << ")";

if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType)) {
unsigned integralWidth = quantizationInterface.getStorageWidth();

// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError() << "illegal storage type size: " << integralWidth;

int64_t defaultMin = quantizationInterface.getDefaultMinimum();
int64_t defaultMax = quantizationInterface.getDefaultMaximum();

if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
storageTypeMax > defaultMax) {
return emitError() << "illegal storage min and storage max: ("
<< storageTypeMin << ":" << storageTypeMax << ")";
}

return success();
}
return success();

return emitError() << "storage type must implement QuantizationInterface";
}

Type QuantizedType::getStorageType() const {
Expand All @@ -87,20 +90,22 @@ int64_t QuantizedType::getStorageTypeMax() const {
}

bool QuantizedType::hasStorageTypeBounds() const {
unsigned int integralWidth = getStorageTypeIntegralWidth();
bool isSignedInteger = isSigned();
int64_t defaultIntegerMin =
getDefaultMinimumForInteger(isSignedInteger, integralWidth);
int64_t defaultIntegerMax =
getDefaultMaximumForInteger(isSignedInteger, integralWidth);
return defaultIntegerMin != getStorageTypeMin() ||
defaultIntegerMax != getStorageTypeMax();
Type storageType = static_cast<ImplType *>(impl)->storageType;
auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType);

int64_t defaultMin = quantizationInterface.getDefaultMinimum();
int64_t defaultMax = quantizationInterface.getDefaultMaximum();

return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
}

unsigned QuantizedType::getStorageTypeIntegralWidth() const {
// NOTE: If ever supporting non-integral storage types, some other scheme
// for determining the width will be needed.
return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
Type storageType = static_cast<ImplType *>(impl)->storageType;
auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType);

return quantizationInterface.getStorageWidth();
}

Type QuantizedType::getExpressedType() const {
Expand Down
74 changes: 40 additions & 34 deletions mlir/lib/Dialect/Quant/IR/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/QuantizationInterface.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"

using namespace mlir;
using namespace quant;

static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
auto typeLoc = parser.getCurrentLocation();
IntegerType type;
Type type;

// Parse storage type (alpha_ident, integer_literal).
StringRef identifier;
Expand All @@ -27,20 +28,28 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
if (result.has_value()) {
if (!succeeded(*result))
return nullptr;
isSigned = !type.isUnsigned();
storageTypeWidth = type.getWidth();
} else if (succeeded(parser.parseKeyword(&identifier))) {
// Otherwise, this must be an unsigned integer (`u` integer-literal).
if (!identifier.consume_front("u")) {
parser.emitError(typeLoc, "illegal storage type prefix");

if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(type)) {
isSigned = quantizationInterface.isStorageSigned();
storageTypeWidth = quantizationInterface.getStorageWidth();
} else {
parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
if (identifier.getAsInteger(10, storageTypeWidth)) {
parser.emitError(typeLoc, "expected storage type width");
} else if (succeeded(parser.parseKeyword(&identifier))) {
// Otherwise, this must be an unsigned integer (`u` integer-literal)
if (identifier.consume_front("u")) {
if (identifier.getAsInteger(10, storageTypeWidth)) {
parser.emitError(typeLoc, "expected storage type width");
return nullptr;
}
isSigned = false;
type = parser.getBuilder().getIntegerType(storageTypeWidth);
} else {
parser.emitError(typeLoc, "illegal quantized storage type alias");
return nullptr;
}
isSigned = false;
type = parser.getBuilder().getIntegerType(storageTypeWidth);
} else {
return nullptr;
}
Expand All @@ -55,17 +64,19 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
return type;
}

static ParseResult parseStorageRange(DialectAsmParser &parser,
IntegerType storageType, bool isSigned,
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
int64_t &storageTypeMin,
int64_t &storageTypeMax) {
int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
isSigned, storageType.getWidth());
int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
isSigned, storageType.getWidth());
int64_t defaultMin, defaultMax;
if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(storageType)) {
defaultMin = quantizationInterface.getDefaultMinimum();
defaultMax = quantizationInterface.getDefaultMaximum();
}

if (failed(parser.parseOptionalLess())) {
storageTypeMin = defaultIntegerMin;
storageTypeMax = defaultIntegerMax;
storageTypeMin = defaultMin;
storageTypeMax = defaultMax;
return success();
}

Expand All @@ -75,11 +86,11 @@ static ParseResult parseStorageRange(DialectAsmParser &parser,
parser.getCurrentLocation(&maxLoc) ||
parser.parseInteger(storageTypeMax) || parser.parseGreater())
return failure();
if (storageTypeMin < defaultIntegerMin) {
if (storageTypeMin < defaultMin) {
return parser.emitError(minLoc, "illegal storage type minimum: ")
<< storageTypeMin;
}
if (storageTypeMax > defaultIntegerMax) {
if (storageTypeMax > defaultMax) {
return parser.emitError(maxLoc, "illegal storage type maximum: ")
<< storageTypeMax;
}
Expand Down Expand Up @@ -113,7 +124,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
static Type parseAnyType(DialectAsmParser &parser) {
IntegerType storageType;
Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
Expand All @@ -134,8 +145,7 @@ static Type parseAnyType(DialectAsmParser &parser) {
}

// Storage type range.
if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
storageTypeMax)) {
if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}

Expand Down Expand Up @@ -322,7 +332,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
/// scale-zero-tensor (`,` scale-zero-tensor)*
/// `}`
static Type parseUniformType(DialectAsmParser &parser) {
IntegerType storageType;
Type storageType;
FloatType expressedType;
unsigned typeFlags = 0;
int64_t storageTypeMin;
Expand Down Expand Up @@ -350,8 +360,7 @@ static Type parseUniformType(DialectAsmParser &parser) {
}

// Storage type range.
if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
storageTypeMax)) {
if (parseStorageRange(parser, storageType, storageTypeMin, storageTypeMax)) {
return nullptr;
}

Expand Down Expand Up @@ -486,12 +495,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const {

static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
// storage type
unsigned storageWidth = type.getStorageTypeIntegralWidth();
bool isSigned = type.isSigned();
if (isSigned) {
out << "i" << storageWidth;
} else {
out << "u" << storageWidth;
if (auto quantizationInterface =
llvm::dyn_cast<QuantizationInterface>(type.getStorageType())) {
out << quantizationInterface.getStorageType();
}

// storageTypeMin and storageTypeMax if not default.
Expand Down
Loading