Commit 2ea692d

add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 11d15d5

File tree: 2 files changed (+29, −0)


tests/observer.py

Lines changed: 10 additions & 0 deletions
@@ -158,6 +158,9 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -182,6 +185,9 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs)
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -203,4 +209,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs)
     if args.strategy == QuantizationStrategy.BLOCK:
        raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, head_dim)
+        return value.flatten(0, 1).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
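
For context, a minimal sketch of the shape transformation the new ATTN_HEAD branch performs, assuming a (batch_size, seq_len, num_heads, head_dim) input layout as in the added comment; the concrete dimensions below are illustrative, not taken from the diff:

import torch

# Illustrative shapes only: 2 sequences of 5 tokens, 4 heads, 64 dims per head
value = torch.randn(2, 5, 4, 64)

# flatten(0, 1) merges the batch and sequence dims; unsqueeze(-2) inserts a
# singleton axis so each head forms its own group for the observer to reduce over
flat = value.flatten(0, 1).unsqueeze(-2)
assert flat.shape == (10, 4, 1, 64)  # (batch_size * seq_len, num_heads, 1, head_dim)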

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 19 additions & 0 deletions
@@ -302,6 +302,25 @@ class MockAttention(torch.nn.Module):
         # group is not supported
         # tensor group is not supported
         # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[0], [3]]),
+            torch.tensor([[8], [11]]),
+            torch.tensor(
+                [
+                    [
+                        [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]],
+                        [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]],
+                    ]
+                ]
+            ),
+            0.16,
+        ),
     ],
 )
 def test_static_attention_quantization(
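
The expected tensor in the new test case is consistent with symmetric int4 fake-quantization applied per head. A minimal sketch that approximately reproduces it, assuming the test input is torch.arange(12) reshaped to (batch, seq_len, num_heads, head_dim) = (1, 2, 2, 3) and that scales are kept in bfloat16; both assumptions are inferences from the per-head min/max tensors above, not taken from the diff:

import torch

# Hypothetical input reconstructed from the min/max tensors above:
# head 0 holds {0, 1, 2, 6, 7, 8}, head 1 holds {3, 4, 5, 9, 10, 11}
x = torch.arange(12, dtype=torch.bfloat16).reshape(1, 2, 2, 3)

max_abs = x.abs().amax(dim=(0, 1, 3))       # per-head absmax: [8., 11.]
scale = (max_abs / 7.5).to(torch.bfloat16)  # int4 symmetric range [-8, 7] spans 15 steps

q = torch.round(x / scale.view(1, 1, -1, 1)).clamp(-8, 7)
dq = q * scale.view(1, 1, -1, 1)
# dq ≈ the expected tensor above, e.g. head 0 dequantizes to
# [0.0000, 1.0703, 2.1406] and [6.4375, 7.5000, 7.5000] up to bf16 rounding

The trailing 0.16 in the parametrization presumably serves as the error tolerance for the test's comparison against this expected output.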
