Skip to content

Commit 2241c5b

Browse files
committed
fix test
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 15450a1 commit 2241c5b

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

tests/test_quantization/test_utils/test_helpers.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,42 @@
3232
[
3333
(
3434
False,
35-
QuantizationStrategy.TENSOR,
35+
"tensor",
3636
torch.Size(
3737
[
3838
1,
3939
]
4040
),
4141
),
42-
(True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])),
43-
(True, QuantizationStrategy.GROUP, torch.Size([1, 1])),
42+
(True, "channel", torch.Size([1, 1])),
43+
(True, "group", torch.Size([1, 1])),
4444
(
4545
False,
46-
QuantizationStrategy.BLOCK,
46+
"block",
4747
torch.Size(
4848
[
4949
1,
5050
]
5151
),
5252
),
53-
(True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
53+
(True, "token", torch.Size([1, 1])),
5454
],
5555
)
5656
def test_calculate_qparams(keepdims, strategy, exp_shape):
57-
value = torch.randn(14, 5)
57+
value = torch.empty(5, 6)
5858
min_val = torch.amin(value, dim=tuple(), keepdims=keepdims)
5959
max_val = torch.amax(value, dim=tuple(), keepdims=keepdims)
6060

6161
if strategy == QuantizationStrategy.GROUP:
6262
args = QuantizationArgs(strategy=strategy, group_size=2)
63+
elif strategy == QuantizationStrategy.BLOCK:
64+
args = QuantizationArgs(strategy=strategy, block_structure=[1, 3])
6365
else:
64-
args = QuantizationArgs(strategy=strategy)
66+
args = QuantizationArgs(
67+
strategy=strategy,
68+
group_size=(2 if strategy == "group" else None),
69+
block_structure=([1, 3] if strategy == "block" else None),
70+
)
6571
scale, zp = calculate_qparams(min_val, max_val, args)
6672
assert scale.shape == exp_shape
6773
assert zp.shape == exp_shape

0 commit comments

Comments
 (0)