This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit f517fe2

Sub-channel quantized type implementation (#120172)
This is an implementation for [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694). To make the review process easier, the PR has been divided into the following commit labels:

1. **Add implementation for sub-channel type:** Includes the class design for `UniformQuantizedSubChannelType`, printer/parser, and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
2. **Add implementation for sub-channel type:** Lowering of the `quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python APIs:** We first define the C APIs and build the Python APIs on top of them.
4. **Add pass to normalize generic ....:** This pass normalizes sub-channel quantized types to per-tensor or per-axis types, if possible.

A design note:

- **Explicitly storing the `quantized_dimensions`, even when they can be derived for ranked tensors.** While it is possible to infer the quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background), there are cases where this leads to ambiguity and round-tripping issues. Consider the example:

  ```
  tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
  ```

  The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, since the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping. Therefore, even for ranked tensors, we explicitly store the quantized dimensions.

Suggestions welcome!

PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.
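The inference ambiguity described in the design note can be demonstrated with a small Python sketch. The `scales_shape` helper below is purely illustrative (it is not part of MLIR); it computes the shape the scales tensor would have for a given set of per-axis block sizes:

```python
def scales_shape(data_shape, blocks):
    """Shape of the scales tensor implied by per-axis block sizes.

    Axes missing from `blocks` are treated as unquantized, i.e. their
    block size degenerates to the full dimension size.
    """
    return [dim // blocks.get(axis, dim) for axis, dim in enumerate(data_shape)]

# For a 2x4 data tensor, two different configurations yield identical
# scales shapes: quantizing axes 0 and 1 with block size 2 (axis 0's
# block spans the whole dimension), or quantizing only axis 1 with
# block size 2.
both_axes = scales_shape((2, 4), {0: 2, 1: 2})   # [1, 2]
axis1_only = scales_shape((2, 4), {1: 2})        # [1, 2]
```

Since the scales shape alone cannot distinguish these two configurations, round-tripping the type would lose which axes were declared quantized, which is why `quantized_dimensions` is stored explicitly.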
1 parent a09c380 commit f517fe2

File tree: 1 file changed (+75 −1 lines)

mlir/lib/Bindings/Python/DialectQuant.cpp

Lines changed: 75 additions & 1 deletion
```diff
@@ -9,10 +9,11 @@
 #include <cstdint>
 #include <vector>
 
+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Dialect/Quant.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace llvm;
@@ -284,6 +285,79 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) {
       },
       "Fixed point values are real numbers divided by a scale.");
 
+  //===-------------------------------------------------------------------===//
+  // UniformQuantizedSubChannelType
+  //===-------------------------------------------------------------------===//
+  auto uniformQuantizedSubChannelType = mlir_type_subclass(
+      m, "UniformQuantizedSubChannelType",
+      mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
+  uniformQuantizedSubChannelType.def_classmethod(
+      "get",
+      [](nb::object cls, unsigned flags, MlirType storageType,
+         MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
+         std::vector<int32_t> quantizedDimensions,
+         std::vector<int64_t> blockSizes, int64_t storageTypeMin,
+         int64_t storageTypeMax) {
+        return cls(mlirUniformQuantizedSubChannelTypeGet(
+            flags, storageType, expressedType, scales, zeroPoints,
+            static_cast<intptr_t>(blockSizes.size()),
+            quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
+            storageTypeMax));
+      },
+      "Gets an instance of UniformQuantizedSubChannel in the same context as "
+      "the provided storage type.",
+      nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
+      nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
+      nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
+      nb::arg("storage_type_min"), nb::arg("storage_type_max"));
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "quantized_dimensions",
+      [](MlirType type) {
+        intptr_t nDim =
+            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+        std::vector<int32_t> quantizedDimensions;
+        quantizedDimensions.reserve(nDim);
+        for (intptr_t i = 0; i < nDim; ++i) {
+          quantizedDimensions.push_back(
+              mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
+        }
+        return quantizedDimensions;
+      },
+      "Gets the quantized dimensions. Each element in the returned list "
+      "represents an axis of the quantized data tensor that has a specified "
+      "block size. The order of elements corresponds to the order of block "
+      "sizes returned by 'block_sizes' method. It means that the data tensor "
+      "is quantized along the i-th dimension in the returned list using the "
+      "i-th block size from block_sizes method.");
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "block_sizes",
+      [](MlirType type) {
+        intptr_t nDim =
+            mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
+        std::vector<int64_t> blockSizes;
+        blockSizes.reserve(nDim);
+        for (intptr_t i = 0; i < nDim; ++i) {
+          blockSizes.push_back(
+              mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
+        }
+        return blockSizes;
+      },
+      "Gets the block sizes for the quantized dimensions. The i-th element in "
+      "the returned list corresponds to the block size for the i-th dimension "
+      "in the list returned by quantized_dimensions method.");
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "scales",
+      [](MlirType type) -> MlirAttribute {
+        return mlirUniformQuantizedSubChannelTypeGetScales(type);
+      },
+      "The scales of the quantized type.");
+  uniformQuantizedSubChannelType.def_property_readonly(
+      "zero_points",
+      [](MlirType type) -> MlirAttribute {
+        return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
+      },
+      "The zero points of the quantized type.");
+
   //===-------------------------------------------------------------------===//
   // CalibratedQuantizedType
   //===-------------------------------------------------------------------===//
```
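The commit message notes that `quant.dcast` is lowered to Linalg operations. Conceptually, sub-channel dequantization broadcasts each block's scale and zero point across its block. The Python sketch below illustrates that arithmetic for a 2-D tensor; the `dequantize` helper is purely illustrative and is not the actual Linalg lowering:

```python
def dequantize(q, scales, zero_points, block_sizes):
    """Dequantize a 2-D tensor of quantized values block by block.

    `scales` and `zero_points` have one entry per block; element (i, j)
    of the data tensor belongs to block (i // block_sizes[0],
    j // block_sizes[1]).
    """
    rows, cols = len(q), len(q[0])
    out = []
    for i in range(rows):
        row = []
        for j in range(cols):
            bi, bj = i // block_sizes[0], j // block_sizes[1]
            row.append((q[i][j] - zero_points[bi][bj]) * scales[bi][bj])
        out.append(row)
    return out

# A 2x4 tensor with block sizes [2, 2]: scales/zero-points have shape
# [1, 2], matching the example in the commit message.
values = dequantize(
    [[2, 2, 2, 2], [2, 2, 2, 2]],   # quantized storage values
    [[0.5, 2.0]],                   # one scale per 2x2 block
    [[0, 0]],                       # one zero point per 2x2 block
    [2, 2],
)
```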
