|
15 | 15 |
|
16 | 16 | import logging
|
17 | 17 | import math
|
| 18 | +import warnings |
18 | 19 | from enum import Enum
|
19 | 20 | from typing import List, Optional
|
20 | 21 |
|
@@ -172,14 +173,41 @@ def _initialize_scale_zero_point(
|
172 | 173 |
|
173 | 174 | if base_name == "weight" and weight_shape is not None:
|
174 | 175 | if quantization_args.strategy == QuantizationStrategy.CHANNEL:
|
175 |
| - # (output_channels, 1) |
| 176 | + # (output_channels, 1) - only for weights |
176 | 177 | expected_shape = (weight_shape[0], 1)
|
177 | 178 | elif quantization_args.strategy in (
|
178 | 179 | QuantizationStrategy.TENSOR_GROUP,
|
179 | 180 | QuantizationStrategy.GROUP,
|
180 | 181 | ):
|
| 182 | + # GROUP/TENSOR_GROUP for both weights and activations |
181 | 183 | num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
|
182 | 184 | expected_shape = (weight_shape[0], max(num_groups, 1))
|
| 185 | + elif quantization_args.strategy == QuantizationStrategy.BLOCK: |
| 186 | + # For block quantization, scale shape should match number of blocks - only for weights |
| 187 | + if quantization_args.block_structure is None: |
| 188 | + raise ValueError("Block quantization requires block_structure to be specified") |
| 189 | + block_height, block_width = quantization_args.block_structure |
| 190 | + rows, cols = weight_shape[-2], weight_shape[-1] |
| 191 | + num_rows_blocks = math.ceil(rows / block_height) |
| 192 | + num_cols_blocks = math.ceil(cols / block_width) |
| 193 | + |
| 194 | + # Warn if dimensions don't divide evenly |
| 195 | + if rows % block_height != 0 or cols % block_width != 0: |
| 196 | + warnings.warn( |
| 197 | + f"Block quantization: tensor shape {weight_shape} does not divide evenly " |
| 198 | + f"by block structure {quantization_args.block_structure}. " |
| 199 | + f"Some blocks will be incomplete which may affect quantization quality.", |
| 200 | + UserWarning |
| 201 | + ) |
| 202 | + |
| 203 | + expected_shape = (num_rows_blocks, num_cols_blocks) |
| 204 | +    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
| 205 | +        # Chains to the OUTER `if base_name == "weight"` test, i.e. activations.
| 206 | +        warnings.warn(
| 207 | +            f"BLOCK quantization not supported for {base_name} activations. "
| 208 | +            f"Falling back to tensor-level quantization.",
| 209 | +            UserWarning
| 210 | +        )
| 211 | +        expected_shape = 1
183 | 211 |
|
184 | 212 | # 3. Identify quantization scale and zp dtype
|
185 | 213 | scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
|
|
0 commit comments