
Commit 233890f

add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent db3f9b7

2 files changed: +29 -0 lines changed

tests/mock_observer.py

Lines changed: 10 additions & 0 deletions
@@ -110,6 +110,9 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -134,6 +137,9 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -155,4 +161,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, head_dim)
+        return value.flatten(0, 1).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
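
For orientation, here is a minimal sketch of what the new ATTN_HEAD branch does to an attention tensor, assuming an input layout of (batch_size, seq_len, num_heads, head_dim) as implied by the flatten pattern and the shape comment in the hunk above (the example sizes are arbitrary):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 5, 4, 8
value = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Merge the batch and sequence dims, then insert a singleton dim so that
# reducing over the last two dims yields one statistic per attention head.
flat = value.flatten(0, 1).unsqueeze(-2)
print(flat.shape)  # torch.Size([10, 4, 1, 8]) == (batch_size * seq_len, num_heads, 1, head_dim)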

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 19 additions & 0 deletions
@@ -303,6 +303,25 @@ class MockAttention(torch.nn.Module):
         # group is not supported
         # tensor group is not supported
         # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[0], [3]]),
+            torch.tensor([[8], [11]]),
+            torch.tensor(
+                [
+                    [
+                        [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]],
+                        [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]],
+                    ]
+                ]
+            ),
+            0.16,
+        ),
     ],
 )
 def test_static_attention_quantization(