Skip to content
This repository was archived by the owner on Nov 27, 2025. It is now read-only.

Commit e30a6c2

Browse files
sartilsramasit
authored andcommitted
Extend Quant dialect with Quantile Quantization type (#53)
* Expanding Quant dialect with Quantile Quantized type * Adding quantile mlir tests * Adding check on quantiles array size and updated mlir tests
1 parent 2cd9b05 commit e30a6c2

File tree

10 files changed

+999
-16
lines changed

10 files changed

+999
-16
lines changed

mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,35 @@ def UniformQuantizedPerAxisType: DialectType<(type
8181
}];
8282
}
8383

84+
def QuantileQuantizedType: DialectType<(type
85+
VarInt:$flags,
86+
Type:$storageType,
87+
Type:$expressedType,
88+
Array<DoubleAPFloatList>:$quantiles,
89+
DoubleAPFloat:$scale,
90+
SignedVarInt:$zeroPoint,
91+
SignedVarInt:$storageTypeMin,
92+
SignedVarInt:$storageTypeMax
93+
)>;
94+
95+
def QuantileQuantizedPerAxisType: DialectType<(type
96+
VarInt:$flags,
97+
Type:$storageType,
98+
Type:$expressedType,
99+
VarInt:$quantizedDimension,
100+
SignedVarInt:$storageTypeMin,
101+
SignedVarInt:$storageTypeMax,
102+
Array<DoubleAPFloatList>:$quantiles,
103+
Array<DoubleAPFloatList>:$scales,
104+
Array<SignedVarIntList>:$zeroPoints
105+
)> {
106+
// Note: builder order differs from bytecode.
107+
let cBuilder = [{
108+
get<$_resultType>(context, flags, storageType, expressedType, quantiles, scales,
109+
zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax)
110+
}];
111+
}
112+
84113
/// This enum contains marker codes used to indicate which attribute is
85114
/// currently being decoded, and how it should be decoded. The order of these
86115
/// codes should generally be unchanged, as any changes will inevitably break
@@ -93,7 +122,9 @@ def QuantDialectTypes : DialectTypes<"Quant"> {
93122
AnyQuantizedTypeWithExpressedType,
94123
CalibratedQuantizedType,
95124
UniformQuantizedType,
96-
UniformQuantizedPerAxisType
125+
UniformQuantizedPerAxisType,
126+
QuantileQuantizedType,
127+
QuantileQuantizedPerAxisType
97128
];
98129
}
99130

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct QuantizedTypeStorage;
2525
struct AnyQuantizedTypeStorage;
2626
struct UniformQuantizedTypeStorage;
2727
struct UniformQuantizedPerAxisTypeStorage;
28+
struct QuantileQuantizedTypeStorage;
29+
struct QuantileQuantizedPerAxisTypeStorage;
2830
struct CalibratedQuantizedTypeStorage;
2931

3032
} // namespace detail
@@ -394,6 +396,128 @@ class UniformQuantizedPerAxisType
394396
}
395397
};
396398

399+
/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
400+
/// look up table array of quantile values.
401+
///
402+
/// Syntax synopsis:
403+
/// Per-layer, all parameters expressed:
404+
/// !quant<quantile[StorageType:ExpressedType]{Quantiles}:{Scale:ZeroPoint}>
405+
/// Per-layer, optional parameters omitted:
406+
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
407+
///
408+
/// StorageType: 'i'|'u' NumBits
409+
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
410+
/// Quantiles: Quantile+
411+
/// Quantile: A legal double value
412+
/// Scale: A legal double value
413+
/// ZeroPoint: An integer value
414+
class QuantileQuantizedType
415+
: public Type::TypeBase<QuantileQuantizedType, UniformQuantizedType,
416+
detail::QuantileQuantizedTypeStorage> {
417+
public:
418+
using Base::Base;
419+
using Base::getChecked;
420+
421+
static constexpr StringLiteral name = "quant.quantile";
422+
423+
/// Gets an instance of the type with all parameters specified but not
424+
/// checked.
425+
static QuantileQuantizedType get(unsigned flags, Type storageType,
426+
Type expressedType,
427+
ArrayRef<double> quantiles, double scale,
428+
int64_t zeroPoint, int64_t storageTypeMin,
429+
int64_t storageTypeMax);
430+
431+
static QuantileQuantizedType
432+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
433+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
434+
double scale, int64_t zeroPoint, int64_t storageTypeMin,
435+
int64_t storageTypeMax);
436+
437+
/// Verifies construction invariants and issues errors/warnings.
438+
static LogicalResult verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
439+
unsigned flags, Type storageType,
440+
Type expressedType, ArrayRef<double> quantiles,
441+
double scale, int64_t zeroPoint,
442+
int64_t storageTypeMin, int64_t storageTypeMax);
443+
444+
/// Gets the quantile values
445+
ArrayRef<double> getQuantiles() const;
446+
447+
// Fixed point values are real numbers divided by a scale.
448+
// Currently, only signed storage types are treated as fixed point.
449+
// A fixed point value can be obtained from an affine value by subtracting
450+
// the zeroPoint.
451+
// In the future, this may be explicit versus implied by type and zeroPoint.
452+
bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
453+
};
454+
455+
/// Represents per-axis QuantileQuantizedType (also known as per-channel
456+
/// quantization).
457+
///
458+
/// Syntax synopsis:
459+
/// Per-axis, all parameters expressed:
460+
/// !quant<quantile[StorageType:ExpressedType:QuantizedDim]{Quantiles}:{QuantParams}>
461+
/// Per-axis, optional parameters omitted:
462+
/// !quant<quantile[StorageType]{Quantiles}:{Scale}>
463+
///
464+
/// StorageType: 'i'|'u' NumBits
465+
/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
466+
/// QuantizedDim: An integer value
467+
/// Quantiles: Quantile+
468+
/// Quantile: A legal double value
469+
/// QuantParams: (Scale ':' ZeroPoint)+
470+
/// Scale: A legal double value
471+
/// ZeroPoint: An integer value
472+
class QuantileQuantizedPerAxisType
473+
: public Type::TypeBase<QuantileQuantizedPerAxisType,
474+
UniformQuantizedPerAxisType,
475+
detail::QuantileQuantizedPerAxisTypeStorage> {
476+
public:
477+
using Base::Base;
478+
using Base::getChecked;
479+
480+
static constexpr StringLiteral name = "quant.quantile_per_axis";
481+
482+
/// Gets an instance of the type with all parameters specified but not
483+
/// checked.
484+
static QuantileQuantizedPerAxisType
485+
get(unsigned flags, Type storageType, Type expressedType,
486+
ArrayRef<double> quantiles, ArrayRef<double> scales,
487+
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
488+
int64_t storageTypeMin, int64_t storageTypeMax);
489+
490+
/// Gets an instance of the type with all specified parameters checked.
491+
/// Returns a nullptr convertible type on failure.
492+
static QuantileQuantizedPerAxisType
493+
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
494+
Type storageType, Type expressedType, ArrayRef<double> quantiles,
495+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
496+
int32_t quantizedDimension, int64_t storageTypeMin,
497+
int64_t storageTypeMax);
498+
499+
/// Verifies construction invariants and issues errors/warnings.
500+
static LogicalResult verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
501+
unsigned flags, Type storageType,
502+
Type expressedType, ArrayRef<double> quantiles,
503+
ArrayRef<double> scales,
504+
ArrayRef<int64_t> zeroPoints,
505+
int32_t quantizedDimension,
506+
int64_t storageTypeMin, int64_t storageTypeMax);
507+
508+
/// Gets the quantile values
509+
ArrayRef<double> getQuantiles() const;
510+
511+
/// Fixed point values are real numbers divided by a scale.
512+
/// Currently, only signed storage types are treated as fixed point.
513+
/// A fixed point value can be obtained from an affine value by subtracting
514+
/// the zeroPoint.
515+
/// In the future, this may be explicit versus implied by type and zeroPoint.
516+
bool isFixedPoint() const {
517+
return isSigned() && !llvm::is_contained(getZeroPoints(), 0);
518+
}
519+
};
520+
397521
/// A quantized type that infers its range from given min/max values.
398522
///
399523
/// Typical syntax:
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Predicates for types in the Quantization dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef DIALECT_QUANT_QUANT_OPS_BASE_
14+
#define DIALECT_QUANT_QUANT_OPS_BASE_
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
def Quantization_Dialect : Dialect {
19+
let name = "quant";
20+
let cppNamespace = "::mlir::quant";
21+
22+
let useDefaultTypePrinterParser = 1;
23+
}
24+
25+
//===----------------------------------------------------------------------===//
26+
// Quantization type definitions
27+
//===----------------------------------------------------------------------===//
28+
29+
class quant_TypedPrimitiveOrContainer<Type etype> :
30+
Type<Or<[etype.predicate,
31+
TensorOf<[etype]>.predicate,
32+
VectorOf<[etype]>.predicate]>,
33+
"primitive/tensor/vector of " # etype.summary>;
34+
35+
// An implementation of QuantizedType.
36+
def quant_QuantizedType :
37+
Type<CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">, "QuantizedType">;
38+
39+
// A primitive type that can represent a real value. This is either a
40+
// floating point value or a quantized type.
41+
def quant_RealPrimitiveType :
42+
Type<Or<[AnyFloat.predicate, quant_QuantizedType.predicate]>,
43+
"real valued primitive (float or quantized type)">;
44+
45+
// A primitive type that can represent a storage value. This is either an
46+
// integer or quantized type.
47+
def quant_StoragePrimitiveType :
48+
Type<Or<[AnySignlessInteger.predicate, quant_QuantizedType.predicate]>,
49+
"quantized storage primitive (integer or quantized type)">;
50+
51+
// A primitive or container of RealPrimitiveType.
52+
def quant_RealValueType :
53+
quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
54+
55+
// A primitive or container of StoragePrimitiveType.
56+
def quant_StorageValueType :
57+
quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
58+
59+
// Either a real valued or storage primitive or container type.
60+
def quant_RealOrStorageValueType :
61+
Type<Or<[quant_RealValueType.predicate, quant_StorageValueType.predicate]>,
62+
"real valued or storage primitive or container type">;
63+
64+
// An implementation of UniformQuantizedType.
65+
def quant_UniformQuantizedType :
66+
DialectType<Quantization_Dialect,
67+
CPred<"::llvm::isa<UniformQuantizedType>($_self)">,
68+
"UniformQuantizedType">;
69+
70+
// An implementation of QuantileQuantizedType.
71+
def quant_QuantileQuantizedType :
72+
DialectType<Quantization_Dialect,
73+
CPred<"::llvm::isa<QuantileQuantizedType>($_self)">,
74+
"QuantileQuantizedType">;
75+
76+
// Predicate for detecting a container or primitive of UniformQuantizedType.
77+
def quant_UniformQuantizedValueType :
78+
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
79+
80+
// Predicate for detecting a container or primitive of QuantileQuantizedType.
81+
def quant_QuantileQuantizedValueType :
82+
quant_TypedPrimitiveOrContainer<quant_QuantileQuantizedType>;
83+
84+
#endif // DIALECT_QUANT_QUANT_OPS_BASE_

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
9393

9494
void QuantDialect::initialize() {
9595
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
96-
UniformQuantizedPerAxisType>();
96+
UniformQuantizedPerAxisType, QuantileQuantizedType,
97+
QuantileQuantizedPerAxisType>();
9798
addOperations<
9899
#define GET_OP_LIST
99100
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"

0 commit comments

Comments
 (0)