|
32 | 32 | [
|
33 | 33 | (
|
34 | 34 | False,
|
35 |
| - QuantizationStrategy.TENSOR, |
| 35 | + "tensor", |
36 | 36 | torch.Size(
|
37 | 37 | [
|
38 | 38 | 1,
|
39 | 39 | ]
|
40 | 40 | ),
|
41 | 41 | ),
|
42 |
| - (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), |
43 |
| - (True, QuantizationStrategy.GROUP, torch.Size([1, 1])), |
| 42 | + (True, "channel", torch.Size([1, 1])), |
| 43 | + (True, "group", torch.Size([1, 1])), |
44 | 44 | (
|
45 | 45 | False,
|
46 |
| - QuantizationStrategy.BLOCK, |
| 46 | + "block", |
47 | 47 | torch.Size(
|
48 | 48 | [
|
49 | 49 | 1,
|
50 | 50 | ]
|
51 | 51 | ),
|
52 | 52 | ),
|
53 |
| - (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), |
| 53 | + (True, "token", torch.Size([1, 1])), |
54 | 54 | ],
|
55 | 55 | )
|
56 | 56 | def test_calculate_qparams(keepdims, strategy, exp_shape):
|
57 |
| - value = torch.randn(14, 5) |
| 57 | + value = torch.empty(5, 6) |
58 | 58 | min_val = torch.amin(value, dim=tuple(), keepdims=keepdims)
|
59 | 59 | max_val = torch.amax(value, dim=tuple(), keepdims=keepdims)
|
60 | 60 |
|
61 | 61 | if strategy == QuantizationStrategy.GROUP:
|
62 | 62 | args = QuantizationArgs(strategy=strategy, group_size=2)
|
| 63 | + elif strategy == QuantizationStrategy.BLOCK: |
| 64 | + args = QuantizationArgs(strategy=strategy, block_structure=[1, 3]) |
63 | 65 | else:
|
64 |
| - args = QuantizationArgs(strategy=strategy) |
| 66 | + args = QuantizationArgs( |
| 67 | + strategy=strategy, |
| 68 | + group_size=(2 if strategy == "group" else None), |
| 69 | + block_structure=([1, 3] if strategy == "block" else None), |
| 70 | + ) |
65 | 71 | scale, zp = calculate_qparams(min_val, max_val, args)
|
66 | 72 | assert scale.shape == exp_shape
|
67 | 73 | assert zp.shape == exp_shape
|
|
0 commit comments