|
9 | 9 | #include <cstdint> |
10 | 10 | #include <vector> |
11 | 11 |
|
| 12 | +#include "mlir-c/BuiltinAttributes.h" |
12 | 13 | #include "mlir-c/Dialect/Quant.h" |
13 | 14 | #include "mlir-c/IR.h" |
14 | | -#include "mlir/Bindings/Python/NanobindAdaptors.h" |
15 | 15 | #include "mlir/Bindings/Python/Nanobind.h" |
| 16 | +#include "mlir/Bindings/Python/NanobindAdaptors.h" |
16 | 17 |
|
17 | 18 | namespace nb = nanobind; |
18 | 19 | using namespace llvm; |
@@ -284,6 +285,79 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { |
284 | 285 | }, |
285 | 286 | "Fixed point values are real numbers divided by a scale."); |
286 | 287 |
|
| 288 | + //===-------------------------------------------------------------------===// |
| 289 | + // UniformQuantizedSubChannelType |
| 290 | + //===-------------------------------------------------------------------===// |
| 291 | + auto uniformQuantizedSubChannelType = mlir_type_subclass( |
| 292 | + m, "UniformQuantizedSubChannelType", |
| 293 | + mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); |
| 294 | + uniformQuantizedSubChannelType.def_classmethod( |
| 295 | + "get", |
| 296 | + [](nb::object cls, unsigned flags, MlirType storageType, |
| 297 | + MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, |
| 298 | + std::vector<int32_t> quantizedDimensions, |
| 299 | + std::vector<int64_t> blockSizes, int64_t storageTypeMin, |
| 300 | + int64_t storageTypeMax) { |
| 301 | + return cls(mlirUniformQuantizedSubChannelTypeGet( |
| 302 | + flags, storageType, expressedType, scales, zeroPoints, |
| 303 | + static_cast<intptr_t>(blockSizes.size()), |
| 304 | + quantizedDimensions.data(), blockSizes.data(), storageTypeMin, |
| 305 | + storageTypeMax)); |
| 306 | + }, |
| 307 | + "Gets an instance of UniformQuantizedSubChannel in the same context as " |
| 308 | + "the provided storage type.", |
| 309 | + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), |
| 310 | + nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), |
| 311 | + nb::arg("quantized_dimensions"), nb::arg("block_sizes"), |
| 312 | + nb::arg("storage_type_min"), nb::arg("storage_type_max")); |
| 313 | + uniformQuantizedSubChannelType.def_property_readonly( |
| 314 | + "quantized_dimensions", |
| 315 | + [](MlirType type) { |
| 316 | + intptr_t nDim = |
| 317 | + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); |
| 318 | + std::vector<int32_t> quantizedDimensions; |
| 319 | + quantizedDimensions.reserve(nDim); |
| 320 | + for (intptr_t i = 0; i < nDim; ++i) { |
| 321 | + quantizedDimensions.push_back( |
| 322 | + mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); |
| 323 | + } |
| 324 | + return quantizedDimensions; |
| 325 | + }, |
| 326 | + "Gets the quantized dimensions. Each element in the returned list " |
| 327 | + "represents an axis of the quantized data tensor that has a specified " |
| 328 | + "block size. The order of elements corresponds to the order of block " |
| 329 | + "sizes returned by 'block_sizes' method. It means that the data tensor " |
| 330 | + "is quantized along the i-th dimension in the returned list using the " |
| 331 | + "i-th block size from block_sizes method."); |
| 332 | + uniformQuantizedSubChannelType.def_property_readonly( |
| 333 | + "block_sizes", |
| 334 | + [](MlirType type) { |
| 335 | + intptr_t nDim = |
| 336 | + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); |
| 337 | + std::vector<int64_t> blockSizes; |
| 338 | + blockSizes.reserve(nDim); |
| 339 | + for (intptr_t i = 0; i < nDim; ++i) { |
| 340 | + blockSizes.push_back( |
| 341 | + mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); |
| 342 | + } |
| 343 | + return blockSizes; |
| 344 | + }, |
| 345 | + "Gets the block sizes for the quantized dimensions. The i-th element in " |
| 346 | + "the returned list corresponds to the block size for the i-th dimension " |
| 347 | + "in the list returned by quantized_dimensions method."); |
| 348 | + uniformQuantizedSubChannelType.def_property_readonly( |
| 349 | + "scales", |
| 350 | + [](MlirType type) -> MlirAttribute { |
| 351 | + return mlirUniformQuantizedSubChannelTypeGetScales(type); |
| 352 | + }, |
| 353 | + "The scales of the quantized type."); |
| 354 | + uniformQuantizedSubChannelType.def_property_readonly( |
| 355 | + "zero_points", |
| 356 | + [](MlirType type) -> MlirAttribute { |
| 357 | + return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); |
| 358 | + }, |
| 359 | + "The zero points of the quantized type."); |
| 360 | + |
287 | 361 | //===-------------------------------------------------------------------===// |
288 | 362 | // CalibratedQuantizedType |
289 | 363 | //===-------------------------------------------------------------------===// |
|
0 commit comments