
Commit e1ca4fd

fix shapes
Signed-off-by: Kyle Sayers <[email protected]>
1 parent c4a5cf4 commit e1ca4fd


4 files changed (+19, -16 lines)


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def _process_quantization(
             inv_perm = torch.argsort(perm)
             output = output.index_select(-1, inv_perm)
 
-    else:  # covers channel, token and tensor strategies
+    else:  # covers tensor, channel, token, and attn_head strategies
         if do_quantize:
             output = _quantize(
                 x=x,
src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 7 additions & 6 deletions
@@ -14,7 +14,7 @@
 
 
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import (
@@ -152,7 +152,7 @@ def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    observed_shape: Tuple[int],
+    observed_shape: Tuple[Union[int, None]],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
@@ -199,7 +199,7 @@ def initialize_qparams(
         expected_shape = (1,)
 
     elif strategy == QuantizationStrategy.TOKEN:
-        expected_shape = (1, 1)
+        raise ValueError("Cannot perform static token quantization")
 
     elif strategy == QuantizationStrategy.CHANNEL:
         if len(observed_shape) < 2:
@@ -235,10 +235,11 @@ def initialize_qparams(
         expected_shape = (num_rows, num_cols)
 
     elif strategy == QuantizationStrategy.ATTN_HEAD:
-        if len(observed_shape) < 2:
-            raise ValueError("Attention quant requires at least 2 observed dimensions")
+        # (batch_size, num_attention_heads, seq_len, head_dim)
+        if len(observed_shape) < 3:
+            raise ValueError("Attention quant requires at least 3 observed dimensions")
 
-        expected_shape = (observed_shape[-2], 1)
+        expected_shape = (observed_shape[-3], 1, 1)
 
     else:
         assert False, f"Unknown strategy {strategy}"
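
A minimal sketch (an illustration, not the library's actual code path) of what the new ATTN_HEAD branch computes: with activations observed as (batch_size, num_attention_heads, seq_len, head_dim), and seq_len possibly unknown (None) at initialization time, the qparams get one entry per head that broadcasts over seq_len and head_dim.

from typing import Tuple, Union

def attn_head_expected_shape(observed_shape: Tuple[Union[int, None], ...]):
    # hypothetical helper mirroring the new branch above
    if len(observed_shape) < 3:
        raise ValueError("Attention quant requires at least 3 observed dimensions")
    # one scale/zero-point per attention head, broadcast over seq_len and head_dim
    return (observed_shape[-3], 1, 1)

assert attn_head_expected_shape((8, None, 64)) == (8, 1, 1)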

tests/mock_observer.py

Lines changed: 2 additions & 2 deletions
@@ -162,7 +162,7 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
         raise ValueError("Block quantization cannot be applied to attention")
 
     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        # (batch_size * seq_len, num_heads, 1, head_dim)
-        return value.flatten(0, 1).unsqueeze(-2)
+        # (batch_size * seq_len, num_heads, 1, 1, head_dim)
+        return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
 
     assert False, f"Unknown strategy {args.strategy}"
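
A quick shape check (an illustration, not part of the commit), assuming value arrives as (batch_size, num_heads, seq_len, head_dim), showing that the new expression produces the shape promised by the updated comment:

import torch

value = torch.arange(12, dtype=torch.float32).reshape(1, 2, 2, 3)  # (batch=1, heads=2, seq=2, dim=3)
flat = value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
assert flat.shape == (2, 2, 1, 1, 3)  # (batch_size * seq_len, num_heads, 1, 1, head_dim)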

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 9 additions & 7 deletions
@@ -310,17 +310,17 @@ class MockAttention(torch.nn.Module):
                 symmetric=True,
                 strategy="attn_head",
             ),
-            torch.tensor([[0], [3]]),
-            torch.tensor([[8], [11]]),
+            torch.tensor([[[0.0]], [[6.0]]]),
+            torch.tensor([[[5.0]], [[11.0]]]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]],
-                        [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]],
+                        [[0.0000, 1.3359, 2.0000], [2.6719, 4.0000, 4.6875]],
+                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
                     ]
                 ]
             ),
-            0.16,
+            0.13,
         ),
     ],
 )
@@ -335,7 +335,7 @@ def test_static_attention_quantization(
           [ 9., 10., 11.]]]])
     """
     # set up activation (and identity weight)
-    batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 2, 3
     input = torch.arange(
         (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
     ).reshape((batch_size, seq_len, num_heads, head_dim))
@@ -344,7 +344,7 @@ def test_static_attention_quantization(
     # initialize quantization parameters
     scheme = QuantizationScheme(targets=[], input_activations=args)
     initialize_qparams(
-        attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16
+        attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16
     )
     attention.quantization_scheme = scheme
     attention.quantization_status = QuantizationStatus.INITIALIZED
@@ -366,5 +366,7 @@ def test_static_attention_quantization(
     assert torch.equal(attention.k_observer.max_vals, exp_max_val)
 
     # check forward pass
+    print(output)
+    print(torch.nn.functional.mse_loss(output, input))
     assert torch.allclose(output, exp_quant.to(output.dtype))
     assert torch.nn.functional.mse_loss(output, input) <= exp_loss
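
A rough derivation (assuming the mock treats dim 1 of k as the head dimension, which is not shown in this diff) of where the new expected min/max values come from: torch.arange(12) reshaped to (1, 2, 2, 3) gives head 0 the values 0-5 and head 1 the values 6-11.

import torch

k = torch.arange(12, dtype=torch.float32).reshape(1, 2, 2, 3)
min_vals = k.amin(dim=(0, 2, 3))  # tensor([0., 6.])
max_vals = k.amax(dim=(0, 2, 3))  # tensor([5., 11.])

The test's expected tensors [[[0.0]], [[6.0]]] and [[[5.0]], [[11.0]]] are these per-head extrema carrying the trailing singleton dims of the new (num_heads, 1, 1) qparam shape.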
