Commit f1d9171

sdasgup3 and ermilindwalekar authored and committed

Sub-channel quantized type implementation (llvm#120172)
This is an implementation of [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694). To make the review process easier, the PR has been divided into the following commit labels:

1. **Add implementation for sub-channel type:** Includes the class design for `UniformQuantizedSubChannelType`, printer/parser and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
2. **Add implementation for sub-channel type:** Lowering of `quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python APIs:** We first define the C APIs and build the Python APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes sub-channel quantized types to per-tensor and per-axis types, if possible.

A design note:

- **Explicitly storing the `quantized_dimensions`, even when they can be derived for ranked tensors.** While it is possible to infer the quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background), there are cases where this can lead to ambiguity and issues with round-tripping. Consider the example:

  ```
  tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
  ```

  The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, since the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping. Therefore, even for ranked tensors, we explicitly store the quantized dimensions.

Suggestions welcome!

PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.
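To illustrate the round-tripping concern in the design note, here is a hedged Python sketch (not part of the patch; `infer_block_sizes` is a hypothetical helper) of what a parser without explicitly stored quantized dimensions would have to do, and why it is ambiguous:

```python
# Hypothetical helper: infer per-axis block sizes from the data shape and the
# scales shape, as a round-tripping parser would have to without explicitly
# stored quantized dimensions.
def infer_block_sizes(data_shape, scales_shape):
    return [d // s for d, s in zip(data_shape, scales_shape)]

# tensor<2x4x...> with scales of shape [1, 2]: the inferred block size for
# axis 0 equals the dimension size (a degenerate "whole axis" block), so a
# type spelled {0:2, 1:2} and one spelled {1:2} describe the same scales
# tensor and cannot be told apart on re-parse -- hence the explicit storage.
assert infer_block_sizes([2, 4], [1, 2]) == [2, 2]
```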
1 parent 6aa4a66 commit f1d9171

File tree

13 files changed: +835 −95 lines changed


mlir/include/mlir/Dialect/Quant/IR/QuantBase.td

Lines changed: 104 additions & 0 deletions

````diff
@@ -279,6 +279,110 @@ def Quant_Dialect : Dialect {
 // Correct. The quantized type now includes 3 scale values, matching the
 // size of dimension 1 of the result tensor.
 %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+
+## Sub-channel quantization integrity
+
+When type `!quant.uniform` contains sub-channel quantization information,
+the following rules are enforced. For efficiency, these rules are actively
+enforced by the verifiers of `quant` dialect ops, but they must be
+respected in any context in which the `!quant.uniform` data type is used,
+such as the header of a `func.func` op, or the input of an arithmetic
+operation.
+
+- A quantized type with sub-channel quantization information must be the
+  element type of a tensor container type, and may not occur directly as
+  the data type of a scalar value.
+
+  ```
+  // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+  // scalar type.
+  %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
+
+  // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
+  // in a `tensor` type.
+  %result = quant.qcast %input : tensor<2x2xf32> to
+            tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+  ```
+
+- The tensor containing the sub-channel quantized type must be ranked.
+
+  ```
+  // Incorrect. Type !quant.uniform specifies sub-channel quantization for an
+  // unranked tensor type.
+  %result = quant.qcast %input : tensor<*xf32> to
+            tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+  ```
+
+- The axis for which a block size is specified must be valid for a tensor
+  of the given rank. Block sizes can be specified for a subset of axes.
+  Any unspecified block size for an axis i defaults to the tensor dimension
+  size of that axis (shape(tensor)[i]).
+
+  ```
+  // Incorrect. The block size is specified for axis 2, which is greater
+  // than the rank of the tensor.
+  %result = quant.qcast %input : tensor<2x2xf32> to
+            tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
+
+  // Incorrect. The block size is specified for a negative axis.
+  %result = quant.qcast %input : tensor<2x2xf32> to
+            tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
+
+  // Correct. The block size for axis 1 is skipped and defaults to 2, the
+  // dimension size of the tensor at axis 1.
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
+
+  // Correct. The block sizes for all axes are skipped, making the
+  // sub-channel type essentially a per-tensor type.
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
+  ```
+
+- The block size for a particular axis must be a positive integer no larger
+  than the dimension size of the tensor along that axis.
+
+  ```
+  // Incorrect. The block size for axis 0 is -1.
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
+
+  // Incorrect. The block size for axis 0 is 8, which is greater than the
+  // dimension size of the tensor at axis 0 (which is 6).
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
+
+  // Correct. The block size for axis 0 is now 3.
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+  ```
+
+- shape(tensor) % blockSizes = 0, where blockSizes = [block size for
+  axis i, for i in [0, 1, ..., rank(tensor)-1]].
+
+  ```
+  // Incorrect. The block size for axis 0 is 4, the corresponding
+  // dimension size is 6, and 6 % 4 != 0.
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
+
+  // Correct. The block size for axis 0 is now 3, making 6 % 3 = 0.
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+  ```
+
+- shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
+
+  ```
+  // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
+  // shape(scales) is [1,2], which is not equal to [6,2]/[3,2].
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
+
+  // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
+  // shape(scales) equals [6,2]/[3,2].
+  %result = quant.qcast %input : tensor<6x2xf32> to
+            tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
 ```
 
 ## Sub-channel quantization integrity
````
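The integrity rules above can be condensed into a small verifier sketch. This is a hedged Python illustration only (assumptions: `block_info` maps axis to block size; this is not the actual MLIR verifier in QuantOps.cpp):

```python
def verify_sub_channel(data_shape, block_info, scales_shape):
    """Check the sub-channel integrity rules on plain Python lists.

    block_info: dict mapping axis -> block size; unspecified axes default
    to the full dimension size (the degenerate, per-tensor-like block).
    """
    rank = len(data_shape)
    for axis, block in block_info.items():
        if not 0 <= axis < rank:
            return False  # axis must be valid for the tensor's rank
        if block <= 0 or data_shape[axis] % block != 0:
            return False  # block size positive; shape(tensor) % blockSizes = 0
    blocks = [block_info.get(i, data_shape[i]) for i in range(rank)]
    # shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes
    return list(scales_shape) == [d // b for d, b in zip(data_shape, blocks)]

assert verify_sub_channel([6, 2], {0: 3}, [2, 1])      # the "correct" example
assert not verify_sub_channel([6, 2], {0: 4}, [1, 1])  # 6 % 4 != 0
assert not verify_sub_channel([2, 2], {2: 1}, [1, 2])  # axis out of range
assert verify_sub_channel([6, 2], {}, [1, 1])          # per-tensor degenerate
```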

mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td

Lines changed: 20 additions & 9 deletions

```diff
@@ -82,6 +82,21 @@ def UniformQuantizedPerAxisType: DialectType<(type
   }];
 }
 
+def UniformQuantizedSubChannelType
+    : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
+          SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
+          Array<SignedVarIntList>:$quantizedDimensions,
+          Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
+          DenseElementsAttr:$zeroPoints)> {
+  // Note: builder order differs from bytecode.
+  let cBuilder = [{
+    get<$_resultType>(context, flags, storageType, expressedType, scales,
+        zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
+        [](int64_t dim) { return static_cast<int32_t>(dim); })), blockSizes,
+        storageTypeMin, storageTypeMax)
+  }];
+}
+
 def QuantileQuantizedType: DialectType<(type
   VarInt:$flags,
   Type:$storageType,
@@ -119,16 +134,12 @@ def QuantileQuantizedPerAxisType: DialectType<(type
 /// compatibility with older bytecode.
 
 def QuantDialectTypes : DialectTypes<"Quant"> {
-  let elems = [
-    ReservedOrDead,
-    AnyQuantizedType,
-    AnyQuantizedTypeWithExpressedType,
-    CalibratedQuantizedType,
-    UniformQuantizedType,
-    UniformQuantizedPerAxisType,
+  let elems = [ReservedOrDead, AnyQuantizedType,
+               AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
+               UniformQuantizedType, UniformQuantizedPerAxisType,
     QuantileQuantizedType,
-    QuantileQuantizedPerAxisType
-  ];
+               QuantileQuantizedPerAxisType,
+               UniformQuantizedSubChannelType];
 }
 
 #endif // QUANT_BYTECODE
```

mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h

Lines changed: 130 additions & 0 deletions

```diff
@@ -401,6 +401,136 @@ class UniformQuantizedPerAxisType
   }
 };
 
+/// Represents sub-channel (also known as blockwise) quantization.
+///
+/// Syntax synopsis:
+///   UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+///       storageType ('<' storageMin ':' storageMax '>')? ':'
+///       expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+///   BlockSizeInfo ::= '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+///   AxisBlock ::= AxisSpec ':' BlockSizeSpec
+///   ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+///   ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+///   ScaleZeroList ::= ScaleZero (',' ScaleZero)*
+///   ScaleZero ::= Scale (':' ZeroPoint)?
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   AxisSpec: An integer value
+///   BlockSizeSpec: An integer value
+///   Scale: An attribute (usually a floating-point value)
+///   ZeroPoint: An attribute (usually an integer value)
+class UniformQuantizedSubChannelType
+    : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+                            detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+  using Base::Base;
+  using Base::getChecked;
+
+  static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedSubChannelType
+  get(unsigned flags, Type storageType, Type expressedType,
+      DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+      ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr-convertible type on failure.
+  static UniformQuantizedSubChannelType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, DenseElementsAttr scales,
+             DenseElementsAttr zeroPoints,
+             ArrayRef<int32_t> quantizedDimensions,
+             ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+             int64_t storageTypeMax);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+                   ArrayRef<int32_t> quantizedDimensions,
+                   ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
+
+  /// Gets the quantization scales. The scales are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the scales tensor
+  /// is determined by the number of blocks along the corresponding dimension in
+  /// the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+  /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The scale value for a specific element in the quantized data tensor at
+  /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+  /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+  DenseElementsAttr getScales() const;
+
+  /// Gets the quantization zero-points. The zero-points are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the zero-point
+  /// tensor is determined by the number of blocks along the corresponding
+  /// dimension in the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+  /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The zero-point value for a specific element in the quantized data tensor
+  /// at position [i0, i1, ..., iR-1] is determined by accessing the
+  /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1,
+  /// ..., iR-1/BR-1].
+  DenseElementsAttr getZeroPoints() const;
+
+  /// Gets the quantized dimensions. Each element in the returned list
+  /// represents an axis of the quantized data tensor that has a specified block
+  /// size. The order of elements corresponds to the order of block sizes
+  /// returned by `getBlockSizes()`.
+  ///
+  /// That is, the data tensor is quantized along the `i`-th dimension in
+  /// the returned list using the `i`-th block size from `getBlockSizes()`.
+  ///
+  /// Note that the type expression does not have to specify the block size for
+  /// all axes in the data tensor. Any unspecified block size for an axis `i`
+  /// defaults to the tensor dimension size of that axis.
+  ///
+  /// For example, for a quantized type:
+  /// `tensor<8x4x2x!quant.uniform<i8:f32:{1:2, 0:8}, {{1.0, 2.0}, {3.0, 4.0}}>`
+  ///
+  /// `getQuantizedDimensions()` returns [1, 0].
+  /// `getBlockSizes()` returns [2, 8].
+  ///
+  /// This indicates that:
+  /// * Axis 1 (second dimension) is quantized with a block size of 2.
+  /// * Axis 0 (first dimension) is quantized with a block size of 8.
+  /// Since axis 2 is not specified, it implicitly has a block size equal to
+  /// the size of the third dimension (which is 2 in this case).
+  ArrayRef<int32_t> getQuantizedDimensions() const;
+
+  /// Gets the block sizes for the quantized dimensions. The `i`-th element in
+  /// the returned list corresponds to the block size for the `i`-th dimension
+  /// in the list returned by `getQuantizedDimensions()`.
+  ///
+  /// See `getQuantizedDimensions()` for more details and examples.
+  ArrayRef<int64_t> getBlockSizes() const;
+
+  /// Gets the block size information. This returns a list of pairs, where each
+  /// pair represents a quantized dimension and its corresponding block size.
+  ///
+  /// For example, for the type:
+  /// `tensor<8x4x!quant.uniform<i8:f32:{1:2, 0:8}, {{2.0, 3.0}}>`
+  ///
+  /// This method returns:
+  /// `[(1, 2), (0, 8)]`
+  ///
+  /// This list indicates that axis 1 has a block size of 2, and axis 0 has a
+  /// block size of 8.
+  const SmallVector<std::pair<int32_t, int64_t>> getBlockSizeInfo() const;
+};
+
 /// QuantileQuantizedType derives from UniformQuantizedType and adds to it a
 /// look up table array of quantile values. The type of the data in the look up
 /// table is determined by the quantileType member: supported quantileType types
```
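The scale-lookup convention documented for `getScales()` can be sketched in Python. This is an illustration of the indexing rule only, not the MLIR API; the nested-list layout and the scale values are assumptions:

```python
def scale_for(index, block_sizes, scales):
    """Scale for the data element at `index`: scales[i0//B0][i1//B1]..."""
    value = scales
    for i, b in zip(index, block_sizes):
        value = value[i // b]  # each axis index maps to its block index
    return value

# A 4x2 data tensor with per-axis block sizes [2, 2] has a 2x1 scales tensor.
scales = [[0.5], [4.0]]
assert scale_for((0, 1), [2, 2], scales) == 0.5  # rows 0-1 share one block
assert scale_for((3, 0), [2, 2], scales) == 4.0  # rows 2-3 share the other
```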

mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,6 +13,8 @@
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
```

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -122,7 +122,7 @@ LogicalResult verifySubChannelQuantization(
   //
   // Therefore, we explicitly disallow the case where d = 0 to maintain
   // consistency and avoid these issues.
-  if (llvm::is_contained(tensorType.getShape(), 0)) {
+  if (llvm::find(tensorType.getShape(), 0) != tensorType.getShape().end()) {
     return op->emitError() << "tensor dimension size of zero is not allowed "
                               "with sub-channel quantization";
   }
@@ -192,7 +192,7 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
 void QuantDialect::initialize() {
   addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
            UniformQuantizedPerAxisType, QuantileQuantizedType,
-           QuantileQuantizedPerAxisType>();
+           QuantileQuantizedPerAxisType, UniformQuantizedSubChannelType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
```
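As a side note on the `d = 0` check in `verifySubChannelQuantization` above: zero-sized dimensions are rejected because divisibility becomes vacuous and the scales tensor collapses. A hedged arithmetic sketch (my illustration, not code from the patch):

```python
# With a zero-sized axis, any positive block size "divides" the dimension,
# and the corresponding scales dimension becomes empty, so the type would
# carry no usable scale data -- hence the explicit verifier error.
dim, block = 0, 3
assert dim % block == 0   # divisibility holds vacuously for every block size
assert dim // block == 0  # the scales tensor gets a zero-sized dimension too
```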
