Skip to content

Commit 8973328

Browse files
committed
remove attn head
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 01357dc commit 8973328

File tree

2 files changed

+0
-29
lines changed

2 files changed

+0
-29
lines changed

tests/observer.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,6 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
158158
.unsqueeze(0)
159159
)
160160

161-
if args.strategy == QuantizationStrategy.ATTN_HEAD:
162-
raise ValueError("attention head quantization cannot be applied to weights")
163-
164161
assert False, f"Unknown strategy {args.strategy}"
165162

166163

@@ -185,9 +182,6 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
185182
if args.strategy == QuantizationStrategy.BLOCK:
186183
raise ValueError("Block quantization cannot be applied to activations")
187184

188-
if args.strategy == QuantizationStrategy.ATTN_HEAD:
189-
raise ValueError("attention head quantization cannot be applied to linear acts")
190-
191185
assert False, f"Unknown strategy {args.strategy}"
192186

193187

@@ -209,8 +203,4 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
209203
if args.strategy == QuantizationStrategy.BLOCK:
210204
raise ValueError("Block quantization cannot be applied to attention")
211205

212-
if args.strategy == QuantizationStrategy.ATTN_HEAD:
213-
# (batch_size * seq_len, num_heads, 1, head_dim)
214-
return value.flatten(0, 1).unsqueeze(-2)
215-
216206
assert False, f"Unknown strategy {args.strategy}"

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -302,25 +302,6 @@ class MockAttention(torch.nn.Module):
302302
# group is not supported
303303
# tensor group is not supported
304304
# block is not supported
305-
(
306-
QuantizationArgs(
307-
num_bits=4,
308-
type="int",
309-
symmetric=True,
310-
strategy="attn_head",
311-
),
312-
torch.tensor([[0], [3]]),
313-
torch.tensor([[8], [11]]),
314-
torch.tensor(
315-
[
316-
[
317-
[[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]],
318-
[[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]],
319-
]
320-
]
321-
),
322-
0.16,
323-
),
324305
],
325306
)
326307
def test_static_attention_quantization(

0 commit comments

Comments (0)