Skip to content

Commit 975cb22

Browse files
authored
[Bugfix] Update expected shape for per token strategy (#210)
* update expected shape for per token strategy
* add tests
* wip
* add helpers test

Signed-off-by: Kyle Sayers <[email protected]>

* remove breakpoint

Signed-off-by: Kyle Sayers <[email protected]>

* remove unnecessary arg

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 1fa514a commit 975cb22

File tree

5 files changed

+159
-15
lines changed

5 files changed

+159
-15
lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ def _initialize_scale_zero_point(
174174
device = get_execution_device(module)
175175

176176
# infer expected scale/zero point shape
177-
expected_shape = 1 # per tensor
177+
if quantization_args.strategy == QuantizationStrategy.TOKEN:
178+
expected_shape = (1, 1)
179+
else:
180+
expected_shape = 1
178181

179182
if base_name == "weight" and weight_shape is not None:
180183
if quantization_args.strategy == QuantizationStrategy.CHANNEL:

tests/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor
4444
min_val = torch.amin(value, dim=dim, keepdims=True)
4545
max_val = torch.amax(value, dim=dim, keepdims=True)
4646
scale, zp = calculate_qparams(min_val, max_val, args)
47-
scale = scale.reshape((1, 1))
48-
zp = zp.reshape((1, 1))
4947
update_parameter_data(module, scale, f"{base_name}_scale")
5048
update_parameter_data(module, zp, f"{base_name}_zero_point")
5149

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,25 @@
1414

1515

1616
import pytest
17+
from compressed_tensors.quantization import (
18+
ActivationOrdering,
19+
QuantizationArgs,
20+
QuantizationScheme,
21+
QuantizationStatus,
22+
QuantizationStrategy,
23+
)
1724
from compressed_tensors.quantization.lifecycle.initialize import (
1825
initialize_module_for_quantization,
1926
)
20-
from compressed_tensors.quantization.quant_args import QuantizationArgs
21-
from compressed_tensors.quantization.quant_config import QuantizationStatus
2227
from torch.nn import Linear
2328

2429

2530
NUM_BITS = 8
31+
Q_PARAM_NAMES = {
32+
"input_activations": "input",
33+
"weights": "weight",
34+
"output_activations": "output",
35+
}
2636

2737

2838
@pytest.mark.parametrize(
@@ -77,3 +87,78 @@ def test_initialize_module_for_quantization(
7787
assert hasattr(layer, "quantization_status")
7888

7989
assert layer.quantization_status == QuantizationStatus.INITIALIZED
90+
91+
92+
@pytest.mark.parametrize(
    "weights,input_activations",
    [
        (
            QuantizationArgs(strategy="tensor"),
            QuantizationArgs(strategy="tensor"),
        ),
        (
            QuantizationArgs(strategy="channel"),
            None,
        ),
        (
            QuantizationArgs(strategy="group", group_size=2),
            None,
        ),
        (
            QuantizationArgs(strategy="group", group_size=2, actorder="group"),
            None,
        ),
        (
            QuantizationArgs(strategy="group", group_size=2, actorder="weight"),
            None,
        ),
        (
            QuantizationArgs(strategy="block"),
            QuantizationArgs(strategy="block"),
        ),
        (
            QuantizationArgs(strategy="token"),
            QuantizationArgs(strategy="token"),
        ),
    ],
)
def test_initialize_quantization_parameters(weights, input_activations):
    """
    Check the shapes of the quantization parameters created at initialization.

    A ``Linear(7, 8)`` layer is initialized with a scheme built from the
    parametrized ``weights`` / ``input_activations`` args. For every configured
    args object, the registered ``<name>_scale`` and ``<name>_zero_point``
    parameters must have the shape expected for its strategy, and group
    activation ordering must additionally register a ``<name>_g_idx`` with one
    entry per weight column.

    :param weights: quantization args applied to the layer weight
    :param input_activations: quantization args applied to the layer input,
        or None when inputs are not quantized
    """
    quantization_scheme = QuantizationScheme(
        targets=["*"],
        weights=weights,
        input_activations=input_activations,
    )
    layer = Linear(7, 8)
    initialize_module_for_quantization(layer, quantization_scheme)

    for q_type in ("input_activations", "weights"):
        args = getattr(quantization_scheme, q_type)
        if args is None:
            continue
        q_param_name = Q_PARAM_NAMES[q_type]

        # scale and zero point
        if args.strategy == QuantizationStrategy.TENSOR:
            expected_shape = (1,)

        elif args.strategy == QuantizationStrategy.CHANNEL:  # only weight
            expected_shape = (layer.weight.shape[0], 1)

        elif args.strategy == QuantizationStrategy.GROUP:  # only weight
            num_groups = layer.weight.shape[1] // args.group_size
            expected_shape = (layer.weight.shape[0], max(num_groups, 1))

        elif args.strategy == QuantizationStrategy.BLOCK:
            expected_shape = (1,)

        elif args.strategy == QuantizationStrategy.TOKEN:
            expected_shape = (1, 1)

        else:
            # Without this branch an unhandled strategy would either raise a
            # NameError or silently reuse the shape left over from the
            # previous loop iteration.
            pytest.fail(f"No expected shape defined for strategy {args.strategy}")

        assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
        assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape

        # g_idx
        if args.actorder == ActivationOrdering.GROUP:
            assert getattr(layer, f"{q_param_name}_g_idx").shape == (
                layer.weight.shape[1],
            )

tests/test_quantization/test_configs/test_strategies.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def test_channelwise(
6767
if input_symmetry is not None:
6868
mock_per_channel_calibration(model, base_name="input", value=inputs)
6969

70-
assert list(model.weight_scale.shape) == [model_shape[1], 1]
71-
assert list(model.weight_zero_point.shape) == [model_shape[1], 1]
70+
assert model.weight_scale.shape == (model_shape[1], 1)
71+
assert model.weight_zero_point.shape == (model_shape[1], 1)
7272

7373

7474
@torch.no_grad
@@ -97,14 +97,14 @@ def test_group(
9797
model, base_name="input", value=inputs, group_size=group_size
9898
)
9999

100-
assert list(model.weight_scale.shape) == [
100+
assert model.weight_scale.shape == (
101101
model_shape[1],
102102
int(model_shape[0] / group_size),
103-
]
104-
assert list(model.weight_zero_point.shape) == [
103+
)
104+
assert model.weight_zero_point.shape == (
105105
model_shape[1],
106106
int(model_shape[0] / group_size),
107-
]
107+
)
108108

109109

110110
@torch.no_grad
@@ -131,8 +131,8 @@ def test_token(
131131
mock_per_channel_calibration(model, base_name="weight", value=model.weight)
132132
mock_per_token_calibration(model, base_name="input", value=inputs)
133133

134-
assert list(model.input_scale.shape) == [1, 1]
135-
assert list(model.input_zero_point.shape) == [1, 1]
134+
assert model.input_scale.shape == (1, 1)
135+
assert model.input_zero_point.shape == (1, 1)
136136

137-
assert list(model.weight_scale.shape) == [256, 1]
138-
assert list(model.weight_zero_point.shape) == [256, 1]
137+
assert model.weight_scale.shape == (256, 1)
138+
assert model.weight_zero_point.shape == (256, 1)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import torch
17+
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
18+
from compressed_tensors.quantization.utils import calculate_qparams
19+
20+
21+
@pytest.mark.parametrize(
    "keepdims,strategy,exp_shape",
    [
        (False, QuantizationStrategy.TENSOR, torch.Size([1])),
        (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])),
        (True, QuantizationStrategy.GROUP, torch.Size([1, 1])),
        (False, QuantizationStrategy.BLOCK, torch.Size([1])),
        (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
    ],
)
def test_calculate_qparams(keepdims, strategy, exp_shape):
    """
    Check the shape of the scale and zero point returned by calculate_qparams.

    The min/max observations are reduced over every dimension of a random
    (14, 5) tensor, with ``keepdims`` controlling whether the reduced
    dimensions are retained, and both returned tensors must match
    ``exp_shape``.

    :param keepdims: forwarded to ``torch.amin`` / ``torch.amax``
    :param strategy: quantization strategy used to build the args
    :param exp_shape: shape expected for both the scale and the zero point
    """
    observed = torch.randn(14, 5)
    reduce_dims = ()  # an empty dim tuple reduces over all dimensions
    lower = torch.amin(observed, dim=reduce_dims, keepdims=keepdims)
    upper = torch.amax(observed, dim=reduce_dims, keepdims=keepdims)

    # only the group strategy is given a group_size
    extra_kwargs = {"group_size": 2} if strategy == QuantizationStrategy.GROUP else {}
    quant_args = QuantizationArgs(strategy=strategy, **extra_kwargs)

    scale, zero_point = calculate_qparams(lower, upper, quant_args)
    for qparam in (scale, zero_point):
        assert qparam.shape == exp_shape

0 commit comments

Comments
 (0)