diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index c9430e9ec..8da18e721 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -167,6 +167,16 @@ def _initialize_scale_zero_point(
     # 2. Infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
+    elif quantization_args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # supports only GQA models, support others when/if needed
+        if base_name == "q":
+            expected_shape = module.config.num_attention_heads
+        elif base_name in ("k", "v"):
+            expected_shape = module.config.num_key_value_heads
+        else:
+            raise ValueError(
+                f"Unsupported target {type(module)} for per-attention-head quantization"
+            )
     else:
         expected_shape = 1
 
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index d9e88353b..5b6e23ee0 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
     BLOCK = "block"
     TOKEN = "token"
     TENSOR_GROUP = "tensor_group"
+    ATTN_HEAD = "attn_head"
 
 
 class DynamicType(str, Enum):
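
For context, a minimal standalone sketch of the shape-inference behavior the first hunk adds. It mirrors the ATTN_HEAD branch outside the library; the SimpleNamespace config and its head counts are hypothetical stand-ins for module.config on a GQA model, not values taken from compressed-tensors.

# Standalone illustration of the per-attention-head shape inference added above.
# The config values are hypothetical; real attention modules would expose
# num_attention_heads / num_key_value_heads via module.config.
from types import SimpleNamespace

config = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)


def infer_attn_head_shape(base_name: str, config) -> int:
    """Mirror of the ATTN_HEAD branch: one scale/zero point per head."""
    if base_name == "q":
        return config.num_attention_heads  # one entry per query head
    elif base_name in ("k", "v"):
        return config.num_key_value_heads  # shared KV heads under GQA
    raise ValueError(
        f"Unsupported base name {base_name!r} for per-attention-head quantization"
    )


for name in ("q", "k", "v"):
    print(name, infer_attn_head_shape(name, config))
# prints: q 32 / k 8 / v 8

With the second hunk, "attn_head" also becomes a valid value for QuantizationStrategy, so configs can select this strategy by name once the enum member exists.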