diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 4b896d37..390b174a 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -234,6 +234,12 @@ def initialize_qparams(
         num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)
 
+    elif strategy == QuantizationStrategy.ATTN_HEAD:
+        if len(observed_shape) < 2:
+            raise ValueError("Attention quant requires at least 2 observed dimensions")
+
+        expected_shape = (observed_shape[-2], 1)
+
     else:
         assert False, f"Unknown strategy {strategy}"
 
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index d9e88353..5b6e23ee 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
     BLOCK = "block"
     TOKEN = "token"
     TENSOR_GROUP = "tensor_group"
+    ATTN_HEAD = "attn_head"
 
 
 class DynamicType(str, Enum):
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index b11e3c0c..1e3e089d 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             QuantizationStrategy.TENSOR,
             QuantizationStrategy.GROUP,
             QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.ATTN_HEAD,
         ):
             if (
                 inputs.strategy == QuantizationStrategy.GROUP
diff --git a/tests/observer.py b/tests/observer.py
index 290153c0..b30d19fa 100644
--- a/tests/observer.py
+++ b/tests/observer.py
@@ -158,6 +158,9 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -182,6 +185,9 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs)
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -203,4 +209,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs)
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, head_dim)
+        return value.flatten(0, 1).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py
index 4adcba98..efc17aec 100644
--- a/tests/test_quantization/lifecycle/test_static_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py
@@ -302,6 +302,25 @@ class MockAttention(torch.nn.Module):
         # group is not supported
         # tensor group is not supported
         # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[0], [3]]),
+            torch.tensor([[8], [11]]),
+            torch.tensor(
+                [
+                    [
+                        [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]],
+                        [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]],
+                    ]
+                ]
+            ),
+            0.16,
+        ),
     ],
 )
 def test_static_attention_quantization(
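
Illustrative sketch (not part of the diff): the new ATTN_HEAD branch in tests/observer.py reshapes attention states of shape (batch_size, seq_len, num_heads, head_dim) into (batch_size * seq_len, num_heads, 1, head_dim), so that a downstream reduction produces one scale/zero-point pair per head, matching expected_shape = (observed_shape[-2], 1) in initialize_qparams. The sizes, the arange input, and the amin/amax reduction below are assumptions for illustration only, not the library's observer implementation; only the flatten(0, 1).unsqueeze(-2) reshape and the per-head qparam shape come from the diff above.

import torch

# Hypothetical sizes chosen for illustration.
batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3

# Attention states in the layout assumed by flatten_attention_for_quantization:
# (batch_size, seq_len, num_heads, head_dim)
value = torch.arange(
    batch_size * seq_len * num_heads * head_dim, dtype=torch.float32
).reshape(batch_size, seq_len, num_heads, head_dim)

# Same reshape as the ATTN_HEAD branch added above:
# (batch_size * seq_len, num_heads, 1, head_dim)
flattened = value.flatten(0, 1).unsqueeze(-2)
assert flattened.shape == (batch_size * seq_len, num_heads, 1, head_dim)

# Reducing over every dim except the head dim (a sketch of what an observer
# could do with this layout) yields one observed min/max per head, i.e.
# qparams of shape (num_heads, 1).
per_head_min = flattened.amin(dim=(0, -2, -1))
per_head_max = flattened.amax(dim=(0, -2, -1))
print(per_head_min)  # tensor([0., 3.])
print(per_head_max)  # tensor([8., 11.])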