Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 8cc42ae

Browse files
authored
Sub-channel quantized type implementation (#120172)
This is an implementation for [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694). In order to make the review process easier, the PR has been divided into the following commit labels: 1. **Add implementation for sub-channel type:** Includes the class design for `UniformQuantizedSubChannelType`, printer/parser and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered. 2. **Add implementation for sub-channel type:** Lowering of `quant.qcast` and `quant.dcast` operations to Linalg operations. 3. **Adding C/Python APIs:** We first define the C-APIs and build the Python-APIs on top of those. 4. **Add pass to normalize generic ....:** This pass normalizes sub-channel quantized types to per-tensor per-axis types, if possible. A design note: - **Explicitly storing the `quantized_dimensions`, even when they can be derived for ranked tensor.** While it's possible to infer quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background), there are cases where this can lead to ambiguity and issues with round-tripping. ``` Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 0:2}, {{s00:z00, s01:z01}}>> ``` The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, as the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping. Therefore, even for ranked tensors, we are explicitly storing the quantized dimensions. Suggestions welcome! PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.
1 parent 45539c6 commit 8cc42ae

File tree

4 files changed

+193
-2
lines changed

4 files changed

+193
-2
lines changed

mlir/include/mlir-c/Dialect/Quant.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
172172
MLIR_CAPI_EXPORTED bool
173173
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
174174

175+
//===---------------------------------------------------------------------===//
176+
// UniformQuantizedSubChannelType
177+
//===---------------------------------------------------------------------===//
178+
179+
/// Returns `true` if the given type is a UniformQuantizedSubChannel.
180+
MLIR_CAPI_EXPORTED bool
181+
mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
182+
183+
/// Creates a UniformQuantizedSubChannelType with the given parameters.
184+
///
185+
/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
186+
/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes`
187+
/// point to `blockSizeInfoLength` number of elements, describing respectively
188+
/// the quantization axis and corresponding block size.
189+
MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
190+
unsigned flags, MlirType storageType, MlirType expressedType,
191+
MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
192+
intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
193+
int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
194+
195+
/// Returns the number of block sizes provided in type.
196+
MLIR_CAPI_EXPORTED intptr_t
197+
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
198+
199+
/// Returns the quantized dimension at the given position.
200+
MLIR_CAPI_EXPORTED int32_t
201+
mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
202+
intptr_t pos);
203+
204+
/// Returns the block size at the given position.
205+
MLIR_CAPI_EXPORTED int64_t
206+
mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);
207+
208+
/// Returns the scales of the quantized type.
209+
MLIR_CAPI_EXPORTED MlirAttribute
210+
mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);
211+
212+
/// Returns the zero-points of the quantized type.
213+
MLIR_CAPI_EXPORTED MlirAttribute
214+
mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
215+
175216
//===---------------------------------------------------------------------===//
176217
// CalibratedQuantizedType
177218
//===---------------------------------------------------------------------===//

mlir/lib/Bindings/Python/DialectQuant.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
#include <cstdint>
1010
#include <vector>
1111

12+
#include "mlir-c/BuiltinAttributes.h"
1213
#include "mlir-c/Dialect/Quant.h"
1314
#include "mlir-c/IR.h"
14-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1515
#include "mlir/Bindings/Python/Nanobind.h"
16+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1617

1718
namespace nb = nanobind;
1819
using namespace llvm;
@@ -284,6 +285,79 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
284285
},
285286
"Fixed point values are real numbers divided by a scale.");
286287

288+
//===-------------------------------------------------------------------===//
289+
// UniformQuantizedSubChannelType
290+
//===-------------------------------------------------------------------===//
291+
auto uniformQuantizedSubChannelType = mlir_type_subclass(
292+
m, "UniformQuantizedSubChannelType",
293+
mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
294+
uniformQuantizedSubChannelType.def_classmethod(
295+
"get",
296+
[](nb::object cls, unsigned flags, MlirType storageType,
297+
MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
298+
std::vector<int32_t> quantizedDimensions,
299+
std::vector<int64_t> blockSizes, int64_t storageTypeMin,
300+
int64_t storageTypeMax) {
301+
return cls(mlirUniformQuantizedSubChannelTypeGet(
302+
flags, storageType, expressedType, scales, zeroPoints,
303+
static_cast<intptr_t>(blockSizes.size()),
304+
quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
305+
storageTypeMax));
306+
},
307+
"Gets an instance of UniformQuantizedSubChannel in the same context as "
308+
"the provided storage type.",
309+
nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
310+
nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
311+
nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
312+
nb::arg("storage_type_min"), nb::arg("storage_type_max"));
313+
uniformQuantizedSubChannelType.def_property_readonly(
314+
"quantized_dimensions",
315+
[](MlirType type) {
316+
intptr_t nDim =
317+
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
318+
std::vector<int32_t> quantizedDimensions;
319+
quantizedDimensions.reserve(nDim);
320+
for (intptr_t i = 0; i < nDim; ++i) {
321+
quantizedDimensions.push_back(
322+
mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
323+
}
324+
return quantizedDimensions;
325+
},
326+
"Gets the quantized dimensions. Each element in the returned list "
327+
"represents an axis of the quantized data tensor that has a specified "
328+
"block size. The order of elements corresponds to the order of block "
329+
"sizes returned by 'block_sizes' method. It means that the data tensor "
330+
"is quantized along the i-th dimension in the returned list using the "
331+
"i-th block size from block_sizes method.");
332+
uniformQuantizedSubChannelType.def_property_readonly(
333+
"block_sizes",
334+
[](MlirType type) {
335+
intptr_t nDim =
336+
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
337+
std::vector<int64_t> blockSizes;
338+
blockSizes.reserve(nDim);
339+
for (intptr_t i = 0; i < nDim; ++i) {
340+
blockSizes.push_back(
341+
mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
342+
}
343+
return blockSizes;
344+
},
345+
"Gets the block sizes for the quantized dimensions. The i-th element in "
346+
"the returned list corresponds to the block size for the i-th dimension "
347+
"in the list returned by quantized_dimensions method.");
348+
uniformQuantizedSubChannelType.def_property_readonly(
349+
"scales",
350+
[](MlirType type) -> MlirAttribute {
351+
return mlirUniformQuantizedSubChannelTypeGetScales(type);
352+
},
353+
"The scales of the quantized type.");
354+
uniformQuantizedSubChannelType.def_property_readonly(
355+
"zero_points",
356+
[](MlirType type) -> MlirAttribute {
357+
return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
358+
},
359+
"The zero points of the quantized type.");
360+
287361
//===-------------------------------------------------------------------===//
288362
// CalibratedQuantizedType
289363
//===-------------------------------------------------------------------===//

mlir/lib/CAPI/Dialect/Quant.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/Dialect/Quant.h"
10+
#include "mlir-c/BuiltinAttributes.h"
1011
#include "mlir/CAPI/Registration.h"
1112
#include "mlir/Dialect/Quant/IR/Quant.h"
1213
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
@@ -194,6 +195,61 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
194195
return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
195196
}
196197

198+
//===---------------------------------------------------------------------===//
199+
// UniformQuantizedSubChannelType
200+
//===---------------------------------------------------------------------===//
201+
202+
/// A type qualifies iff the unwrapped MLIR type is the sub-channel variant.
bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
  mlir::Type unwrapped = unwrap(type);
  return isa<quant::UniformQuantizedSubChannelType>(unwrapped);
}
205+
206+
/// Constructs a sub-channel quantized type. Both attribute arguments must be
/// DenseElementsAttrs; a null MlirType is returned otherwise.
MlirType mlirUniformQuantizedSubChannelTypeGet(
    unsigned flags, MlirType storageType, MlirType expressedType,
    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
    int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin,
    int64_t storageTypeMax) {
  auto scaleValues = dyn_cast<mlir::DenseElementsAttr>(unwrap(scalesAttr));
  if (!scaleValues)
    return {};

  auto zeroPointValues =
      dyn_cast<mlir::DenseElementsAttr>(unwrap(zeroPointsAttr));
  if (!zeroPointValues)
    return {};

  // Both arrays carry `nDims` entries: one (axis, block size) pair each.
  llvm::ArrayRef<int32_t> axes(quantizedDimensions, nDims);
  llvm::ArrayRef<int64_t> sizes(blockSizes, nDims);
  return wrap(quant::UniformQuantizedSubChannelType::get(
      flags, unwrap(storageType), unwrap(expressedType), scaleValues,
      zeroPointValues, axes, sizes, storageTypeMin, storageTypeMax));
}
224+
225+
/// Number of (axis, block size) pairs stored on the type.
intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return static_cast<intptr_t>(subChannelType.getBlockSizes().size());
}
230+
231+
/// Quantized axis at index `pos`.
int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
                                                                intptr_t pos) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return subChannelType.getQuantizedDimensions()[pos];
}
236+
237+
/// Block size at index `pos`.
int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type,
                                                       intptr_t pos) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return subChannelType.getBlockSizes()[pos];
}
242+
243+
/// Scales attribute of the quantized type.
MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return wrap(subChannelType.getScales());
}
247+
248+
/// Zero-points attribute of the quantized type.
MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) {
  auto subChannelType =
      cast<quant::UniformQuantizedSubChannelType>(unwrap(type));
  return wrap(subChannelType.getZeroPoints());
}
252+
197253
//===---------------------------------------------------------------------===//
198254
// CalibratedQuantizedType
199255
//===---------------------------------------------------------------------===//

mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55

6-
from mlir.ir import Type
6+
from mlir.ir import DenseElementsAttr, Type
77

88
__all__ = [
99
"QuantizedType",
@@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType):
109109
@property
110110
def is_fixed_point(self) -> bool: ...
111111

112+
class UniformQuantizedSubChannelType(QuantizedType):
    # Stub for the sub-channel (blockwise) uniform quantized type. The return
    # annotation on `get` was missing; forward references are valid unquoted
    # in .pyi stubs (PEP 484).

    @classmethod
    def get(cls, flags: int, storage_type: Type, expressed_type: Type,
            scales: DenseElementsAttr, zero_points: DenseElementsAttr,
            quantized_dimensions: list[int], block_sizes: list[int],
            storage_type_min: int,
            storage_type_max: int) -> UniformQuantizedSubChannelType:
        ...

    @property
    def quantized_dimensions(self) -> list[int]: ...

    @property
    def block_sizes(self) -> list[int]: ...

    @property
    def scales(self) -> DenseElementsAttr: ...

    @property
    def zero_points(self) -> DenseElementsAttr: ...
112132

113133
def CalibratedQuantizedType(QuantizedType):
114134

0 commit comments

Comments
 (0)