Commit 5478b43
fix block quantization initialization (#403)
* fix block quantization initialization
* fix style
* fix style
* fix style

Signed-off-by: shanjiaz <[email protected]>
Parent: 17a746c

File tree: 2 files changed, +42 −4 lines

- src/compressed_tensors/quantization/lifecycle/initialize.py
- tests/test_quantization/lifecycle/test_initialize.py

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@
 
 import logging
 import math
+import warnings
 from enum import Enum
 from typing import List, Optional
 
```

```diff
@@ -172,14 +173,41 @@ def _initialize_scale_zero_point(
 
     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
+            # (output_channels, 1) - only for weights
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy in (
             QuantizationStrategy.TENSOR_GROUP,
             QuantizationStrategy.GROUP,
         ):
+            # GROUP/TENSOR_GROUP for both weights and activations
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
+        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+            # For block quantization, scale shape should match number of blocks - only for weights
+            if quantization_args.block_structure is None:
+                raise ValueError("Block quantization requires block_structure to be specified")
+            block_height, block_width = quantization_args.block_structure
+            rows, cols = weight_shape[-2], weight_shape[-1]
+            num_rows_blocks = math.ceil(rows / block_height)
+            num_cols_blocks = math.ceil(cols / block_width)
+
+            # Warn if dimensions don't divide evenly
+            if rows % block_height != 0 or cols % block_width != 0:
+                warnings.warn(
+                    f"Block quantization: tensor shape {weight_shape} does not divide evenly "
+                    f"by block structure {quantization_args.block_structure}. "
+                    f"Some blocks will be incomplete which may affect quantization quality.",
+                    UserWarning
+                )
+
+            expected_shape = (num_rows_blocks, num_cols_blocks)
+    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+        warnings.warn(
+            f"BLOCK quantization not supported for {base_name} activations. "
+            f"Falling back to tensor-level quantization.",
+            UserWarning
+        )
+        expected_shape = 1
 
     # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
```
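
For a concrete sense of the shape math above: BLOCK-quantized weights get one scale per (block_height × block_width) tile, with ceil-division so that partial edge blocks still receive a scale. A minimal standalone sketch of that logic (the shapes and block structures below are made-up examples, not values from this commit):

```python
import math
import warnings


def block_scale_shape(weight_shape, block_structure):
    """Mirror the ceil-division logic added in initialize.py: one scale
    per block, warning when blocks do not tile the weight evenly."""
    block_height, block_width = block_structure
    rows, cols = weight_shape[-2], weight_shape[-1]
    if rows % block_height != 0 or cols % block_width != 0:
        warnings.warn(
            f"shape {weight_shape} does not divide evenly by {block_structure}",
            UserWarning,
        )
    return (math.ceil(rows / block_height), math.ceil(cols / block_width))


# A (512, 768) weight with 128x128 blocks needs a (4, 6) grid of scales:
assert block_scale_shape((512, 768), [128, 128]) == (4, 6)
# A (100, 100) weight does not tile evenly: it warns and rounds up to (1, 1):
assert block_scale_shape((100, 100), [128, 128]) == (1, 1)
```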

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -174,8 +174,8 @@ def test_initialize_module_for_quantization_offloaded(
         ),
     ),
     (
-        QuantizationArgs(strategy="block"),
-        QuantizationArgs(strategy="block"),
+        QuantizationArgs(strategy="block", block_structure=[2, 4]),
+        None,
     ),
     (
         QuantizationArgs(strategy="token"),
```

```diff
@@ -227,7 +227,17 @@ def test_initialize_quantization_parameters(weights, input_activations):
             expected_shape = (layer.weight.shape[0], max(num_groups, 1))
 
         elif args.strategy == QuantizationStrategy.BLOCK:
-            expected_shape = (1,)
+            # For block quantization, only weights get block-level scales
+            # Activations fall back to tensor-level since shape is unknown at init
+            if q_type == "weights" and args.block_structure is not None:
+                block_height, block_width = args.block_structure
+                rows, cols = layer.weight.shape[-2], layer.weight.shape[-1]
+                num_rows_blocks = math.ceil(rows / block_height)
+                num_cols_blocks = math.ceil(cols / block_width)
+                expected_shape = (num_rows_blocks, num_cols_blocks)
+            else:
+                # For activations or when block_structure is None
+                expected_shape = (1,)
 
         elif args.strategy == QuantizationStrategy.TOKEN:
             expected_shape = (1, 1)
```
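
As a quick back-of-the-envelope check of the updated expectation: with the test's block_structure=[2, 4] and a hypothetical (8, 8) Linear weight (the actual layer shape in the test suite may differ), the scale grid is ceil(8/2) × ceil(8/4):

```python
import math

block_structure = [2, 4]  # from the updated test case
weight_shape = (8, 8)     # hypothetical weight shape, not taken from the test

block_height, block_width = block_structure
rows, cols = weight_shape
expected_shape = (math.ceil(rows / block_height), math.ceil(cols / block_width))
assert expected_shape == (4, 2)  # one scale per 2x4 block
```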
