Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _process_quantization(
inv_perm = torch.argsort(perm)
output = output.index_select(-1, inv_perm)

else: # covers channel, token and tensor strategies
else: # covers tensor, channel, token, and attn_head strategies
if do_quantize:
output = _quantize(
x=x,
Expand Down
13 changes: 10 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


import logging
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from compressed_tensors.quantization import (
Expand Down Expand Up @@ -152,7 +152,7 @@ def initialize_qparams(
module: Module,
base_name: str,
quantization_args: QuantizationArgs,
observed_shape: Tuple[int],
observed_shape: Tuple[Union[int, None]],
observed_dtype: torch.dtype,
force_zero_point: bool = True,
):
Expand Down Expand Up @@ -199,7 +199,7 @@ def initialize_qparams(
expected_shape = (1,)

elif strategy == QuantizationStrategy.TOKEN:
expected_shape = (1, 1)
raise ValueError("Cannot perform static token quantization")

elif strategy == QuantizationStrategy.CHANNEL:
if len(observed_shape) < 2:
Expand Down Expand Up @@ -234,6 +234,13 @@ def initialize_qparams(
num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
expected_shape = (num_rows, num_cols)

elif strategy == QuantizationStrategy.ATTN_HEAD:
# (batch_size, num_attention_heads, seq_len, head_dim)
if len(observed_shape) < 3:
raise ValueError("Attention quant requires at least 3 observed dimensions")

expected_shape = (observed_shape[-3], 1, 1)

else:
assert False, f"Unknown strategy {strategy}"

Expand Down
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
BLOCK = "block"
TOKEN = "token"
TENSOR_GROUP = "tensor_group"
ATTN_HEAD = "attn_head"


class DynamicType(str, Enum):
Expand Down
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
QuantizationStrategy.TENSOR,
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.ATTN_HEAD,
):
if (
inputs.strategy == QuantizationStrategy.GROUP
Expand Down
19 changes: 17 additions & 2 deletions tests/mock_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def flatten_for_quantization(


def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
# value.shape = (num_rows, num_cols)

if args.strategy == QuantizationStrategy.TENSOR:
# (1, 1, num_weight_elems)
return value.reshape((1, 1, -1))
Expand Down Expand Up @@ -110,10 +112,15 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
.unsqueeze(0)
)

if args.strategy == QuantizationStrategy.ATTN_HEAD:
raise ValueError("attention head quantization cannot be applied to weights")

assert False, f"Unknown strategy {args.strategy}"


def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
# value.shape = (batch_size, seq_len, hidden_dim)

if args.strategy == QuantizationStrategy.TENSOR:
# (batch_size * seq_len, 1, hidden_dim)
return value.reshape((-1, 1, value.size(-1)))
Expand All @@ -134,14 +141,18 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
if args.strategy == QuantizationStrategy.BLOCK:
raise ValueError("Block quantization cannot be applied to activations")

if args.strategy == QuantizationStrategy.ATTN_HEAD:
raise ValueError("attention head quantization cannot be applied to linear acts")

assert False, f"Unknown strategy {args.strategy}"


def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
# value.shape = (batch_size, num_heads, seq_len, head_dim)

if args.strategy == QuantizationStrategy.TENSOR:
# (batch_size, seq_len, num_heads, head_dim)
# (batch_size * seq_len, 1, num_heads * head_dim)
return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)

if args.strategy == QuantizationStrategy.TOKEN:
raise ValueError("Token quantization cannot be applied to attention")
Expand All @@ -155,4 +166,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
if args.strategy == QuantizationStrategy.BLOCK:
raise ValueError("Block quantization cannot be applied to attention")

if args.strategy == QuantizationStrategy.ATTN_HEAD:
# (batch_size * seq_len, num_heads, 1, 1, head_dim)
Copy link
Contributor

@brian-dellabetta brian-dellabetta Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we want the 1, 1 here?

return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)

assert False, f"Unknown strategy {args.strategy}"
63 changes: 50 additions & 13 deletions tests/test_quantization/lifecycle/test_static_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,45 +287,82 @@ class MockAttention(torch.nn.Module):
strategy="tensor",
),
torch.tensor([0.0]),
torch.tensor([11.0]),
torch.tensor([23.0]),
torch.tensor(
[
[
[[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
[[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
[
[0.0000, 0.0000, 3.0625, 3.0625],
[3.0625, 6.1250, 6.1250, 6.1250],
[9.1875, 9.1875, 9.1875, 12.2500],
],
[
[12.2500, 12.2500, 15.3125, 15.3125],
[15.3125, 18.3750, 18.3750, 18.3750],
[21.5000, 21.5000, 21.5000, 21.5000],
],
]
]
),
0.19,
0.81,
),
# static token is not supported
# channel is not supported
# group is not supported
# tensor group is not supported
# block is not supported
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="attn_head",
),
torch.tensor([[[0.0]], [[12.0]]]),
torch.tensor([[[11.0]], [[23.0]]]),
torch.tensor(
[
[
[
[0.0000, 1.4688, 1.4688, 2.9375],
[4.4062, 4.4062, 5.8750, 7.3438],
[7.3438, 8.8125, 10.2500, 10.2500],
],
[
[12.2500, 12.2500, 15.3125, 15.3125],
[15.3125, 18.3750, 18.3750, 18.3750],
[21.5000, 21.5000, 21.5000, 21.5000],
],
]
]
),
0.55,
),
],
)
def test_static_attention_quantization(
args, exp_min_val, exp_max_val, exp_quant, exp_loss
):
"""
input = tensor([[[[ 0., 1., 2.],
[ 3., 4., 5.]],
input = tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],

[[ 6., 7., 8.],
[ 9., 10., 11.]]]])
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]]])
"""
# set up activation (and identity weight)
batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3
# set up attention
batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
input = torch.arange(
(batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
).reshape((batch_size, seq_len, num_heads, head_dim))
(batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
).reshape((batch_size, num_heads, seq_len, head_dim))
attention = MockAttention()

# initialize quantization parameters
scheme = QuantizationScheme(targets=[], input_activations=args)
initialize_qparams(
attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16
attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16
)
attention.quantization_scheme = scheme
attention.quantization_status = QuantizationStatus.INITIALIZED
Expand Down