Commit 199f274

activations have one row

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 6672617

4 files changed: +41 −41 lines changed


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 7 additions & 14 deletions
```diff
@@ -280,17 +280,8 @@ def _process_quantization(
                 f"by the given group_size {group_size}"
             )
 
-        # support column-order (default) quantization as well as other orderings
-        # such as activation ordering. Below checks if g_idx has been initialized
-        is_column_order = g_idx is None or -1 in g_idx
-        if is_column_order:
-            num_groups = int(ceil(columns / group_size))
-            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-
-        else:
-            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-            group_sizes = group_sizes[torch.argsort(group_indices)]
-
+        # permute groups
+        if g_idx is not None:
             perm = torch.argsort(g_idx)
             x = x.index_select(-1, perm)
 
@@ -299,6 +290,8 @@ def _process_quantization(
             ceil(x.shape[-1] / group_size),
             group_size,
         )
+        # we should potentially be folding reshaped_dims[0] into x.shape[-2]
+        # in order to allow for multi-headed activations
         x = x.unflatten(-1, reshaped_dims)
 
         if do_quantize:
@@ -325,9 +318,9 @@ def _process_quantization(
             output = output.flatten(-2, -1)
             output = output.to(output_dtype)
 
-        if not is_column_order:
-            inv_perm = torch.argsort(perm)
-            output = output.index_select(-1, inv_perm)
+        # unpermute groups
+        if g_idx is not None:
+            x = x.index_select(-1, g_idx)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
```
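For context on the pattern these hunks simplify: `g_idx` assigns each column to a quantization group, `torch.argsort(g_idx)` yields a permutation that makes each group's columns contiguous, and applying `argsort` to that permutation gives its inverse. Below is a minimal, self-contained sketch of the permute → group-reshape → unpermute round trip; it is illustrative only and not the library's implementation (the function name and shapes are invented for the example).

```python
import torch

def group_permute_roundtrip(x: torch.Tensor, g_idx: torch.Tensor, group_size: int):
    """Sketch: permute columns into contiguous groups, reshape, then undo.

    Assumes x.shape[-1] is divisible by group_size and that g_idx assigns
    each column a group index.
    """
    # columns belonging to the same group become adjacent
    perm = torch.argsort(g_idx)
    x = x.index_select(-1, perm)

    # view as (..., num_groups, group_size) so per-group scales broadcast
    x = x.unflatten(-1, (x.shape[-1] // group_size, group_size))

    # ... quantize/dequantize per group here ...

    # flatten back, then invert the permutation: argsort(perm) is perm's inverse
    x = x.flatten(-2, -1)
    inv_perm = torch.argsort(perm)
    return x.index_select(-1, inv_perm)

# round trip with no quantization applied returns the input unchanged
x = torch.randn(2, 8)
g_idx = torch.tensor([1, 0, 1, 0, 1, 0, 1, 0])
assert torch.equal(group_permute_roundtrip(x, g_idx, group_size=4), x)
```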

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 12 additions & 23 deletions
```diff
@@ -14,10 +14,8 @@
 
 
 import logging
-import math
-import warnings
 from enum import Enum
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
@@ -32,7 +30,11 @@
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import (
+    is_fp4,
+    is_kv_cache_quant_scheme,
+    strict_divide,
+)
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -102,7 +104,7 @@ def initialize_module_for_quantization(
     if scheme.input_activations is not None:
         base_name = "input"
         args = scheme.input_activations
-        observed_shape = weight.shape[-1:]
+        observed_shape = (1, weight.size(-1))
         observed_dtype = weight.dtype
 
     if scheme.weights is not None:
@@ -148,7 +150,7 @@ def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    observed_shape: torch.Size,
+    observed_shape: Tuple[int],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
@@ -191,8 +193,8 @@ def _initialize_scale_zero_point(
             raise ValueError("Group quant requires at least 1 observed dimension")
 
         group_size = quantization_args.group_size
-        num_groups = _strict_divide(observed_shape[-1], group_size, strategy)
-        expected_shape = (num_groups, group_size)
+        num_groups = strict_divide(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)
 
         # initialize activation ordering if applicable
         if actorder == ActivationOrdering.GROUP:
@@ -208,8 +210,8 @@ def _initialize_scale_zero_point(
            raise ValueError("Block quant requires at least 2 observed dimensions")
 
         block_structure = quantization_args.block_structure
-        num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy)
-        num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy)
+        num_rows = strict_divide(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = strict_divide(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)
 
     # 2. Identify quantization scale and zp dtype
@@ -264,16 +266,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int:
-    out = observed // divisor
-    if out * divisor != observed:
-        raise ValueError(
-            f"{strategy} quantization strategy requires strict division of "
-            f"weight/activation size {observed} and group/block size {divisor}. "
-            "consider reducing the group/block size or ignoring modules with weights "
-            f"not divisible by {divisor}"
-        )
-
-    return out
```
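The commit title refers to the `observed_shape = (1, weight.size(-1))` change: input activations are now observed as having one row, so the group-scale shape `(*observed_shape[:-1], num_groups)` carries that leading 1, where the old code produced `(num_groups, group_size)`. A small sketch of the shape arithmetic (the helper name here is invented; `strict_divide` itself appears in helpers.py below):

```python
from typing import Tuple

def group_scale_shape(observed_shape: Tuple[int, ...], group_size: int) -> Tuple[int, ...]:
    """Sketch of the new expected_shape logic for group quantization:
    one scale per group along the last dimension, leading dims preserved."""
    num_groups = observed_shape[-1] // group_size  # strict_divide minus the check
    return (*observed_shape[:-1], num_groups)

# weights of shape (out_features, in_features): one row of scales per output channel
assert group_scale_shape((4096, 4096), 128) == (4096, 32)
# input activations now observe shape (1, hidden), so their scales have one row
assert group_scale_shape((1, 4096), 128) == (1, 32)
```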

src/compressed_tensors/quantization/quant_args.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -283,8 +283,9 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         has_block_structure = block_structure is not None
         if has_block_strategy != has_block_structure:
             raise ValueError(
-                "Block strategy requires `block_structure`, and vice versa. "
-                f"Instead got ({strategy}, {block_structure})"
+                "`strategy = block` requires `block_structure != None`, and vice versa."
+                f" Instead got `strategy={strategy}` and "
+                f"`block_structure={block_structure}`"
             )
 
         # validate group strategy
@@ -296,8 +297,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         has_actorder = actorder is not None
         if has_group_strategy != has_group_size:
             raise ValueError(
-                "Group strategies require `group_size`, and vice versa. "
-                f"Instead got ({strategy}, {group_size})"
+                "`strategy = group` requires `group_size != None`, and vice versa. "
+                f"Instead got `strategy={strategy}` and `group_size={group_size}`"
             )
         if has_actorder and not has_group_strategy:
             raise ValueError(
```
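The `has_group_strategy != has_group_size` comparison is a boolean XOR: the error fires exactly when one of the two settings is provided without the other. A standalone sketch of the pattern with the reworded message (function name invented for illustration):

```python
def validate_group_settings(strategy: str, group_size=None):
    """XOR check: strategy "group" and group_size must be given together
    or not at all."""
    has_group_strategy = strategy == "group"
    has_group_size = group_size is not None
    if has_group_strategy != has_group_size:  # exactly one was provided
        raise ValueError(
            "`strategy = group` requires `group_size != None`, and vice versa. "
            f"Instead got `strategy={strategy}` and `group_size={group_size}`"
        )

validate_group_settings("group", 128)  # ok: both provided
validate_group_settings("tensor")      # ok: neither provided
# validate_group_settings("group")     # would raise ValueError
```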

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -48,6 +48,7 @@
     "calculate_qparams",
     "generate_gparam",
     "is_fp4",
+    "strict_divide",
 ]
 
 # target the self_attn layer
@@ -477,3 +478,19 @@ def generate_gparam(
     max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
     global_scale = scale_data.max * quant_data.max / max_val_pos
     return global_scale.to(dtype).reshape([1])
+
+
+def strict_divide(
+    observed: int, divisor: int, strategy: Optional[QuantizationStrategy] = None
+) -> int:
+    out = observed // divisor
+    if out * divisor != observed:
+        if strategy is not None:
+            raise ValueError(
+                f"{strategy} quantization strategy requires strict division of "
+                f"weight/activation size {observed} and group/block size {divisor}. "
+                "consider reducing the group/block size or ignoring modules with "
+                f"weights not divisible by {divisor}"
+            )
+
+    return out
```
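A usage note on the now-public helper, based on the definition above: passing a strategy turns an inexact division into a descriptive error, while omitting it makes `strict_divide` fall back to silent floor division. The import paths below are assumptions inferred from this diff's own imports.

```python
from compressed_tensors.quantization.quant_args import QuantizationStrategy
from compressed_tensors.quantization.utils import strict_divide

assert strict_divide(4096, 128) == 32  # exact division passes through

# with a strategy, an inexact division raises the descriptive ValueError
try:
    strict_divide(100, 64, QuantizationStrategy.GROUP)
except ValueError as err:
    print(err)

# without a strategy, the same inexact division silently floor-divides
assert strict_divide(100, 64) == 1
```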
