
Commit 281f1c6 (1 parent: bd1d083)

tests: rely on apply_quantization_config to init scale/zero-point; remove manual creation

1 file changed: +19 -36 lines

tests/test_compressors/quantized_compressors/test_asymmetric_decompression.py (19 additions, 36 deletions)
@@ -96,22 +96,7 @@ def test_end_to_end_asymmetric_quantization(strategy, group_size):
     )
     apply_quantization_config(model, quant_config)
 
-    for name, module in model.named_modules():
-        if isinstance(module, Linear):
-            weight = module.weight
-            if strategy == QuantizationStrategy.CHANNEL:
-                scale_shape = (weight.shape[0], 1)
-            else:
-                scale_shape = (weight.shape[0], weight.shape[1] // group_size)
-
-            module.weight_scale = torch.nn.Parameter(
-                torch.rand(scale_shape) * 0.1,
-                requires_grad=False
-            )
-            module.weight_zero_point = torch.nn.Parameter(
-                torch.randint(-8, 8, scale_shape, dtype=torch.int8),
-                requires_grad=False
-            )
+
 
     compressor = PackedQuantizationCompressor(config=quant_config)
     quantized_modules_to_scheme = {
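
This hunk drops the hand-rolled initialization: the test no longer attaches random weight_scale / weight_zero_point parameters to each Linear and instead trusts apply_quantization_config to create them. A minimal sketch of the invariant the test now relies on; the helper name below is illustrative and not part of the diff:

import torch
from torch.nn import Linear

def assert_quant_params_initialized(model: torch.nn.Module) -> None:
    # Illustrative check (hypothetical helper): after
    # apply_quantization_config(model, quant_config), each quantized Linear
    # should already carry the parameters the deleted loop created by hand.
    for name, module in model.named_modules():
        if isinstance(module, Linear):
            assert hasattr(module, "weight_scale"), f"{name} missing weight_scale"
            assert hasattr(module, "weight_zero_point"), f"{name} missing weight_zero_point"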
@@ -168,34 +153,32 @@ def test_asymmetric_quantization_accuracy(num_bits):
     """
     with tempfile.TemporaryDirectory() as tmp_dir:
         tmp_path = Path(tmp_dir)
-
+
         shape = (256, 512)
-        weights = torch.randn(shape) + 2.0
-
+        biased_weights = torch.randn(shape) + 2.0
+
         quant_config = create_asymmetric_quant_config(
             num_bits=num_bits,
             strategy=QuantizationStrategy.GROUP,
-            group_size=128
+            group_size=128,
         )
-
-        group_size = 128
-        num_groups = shape[1] // group_size
-        scale_shape = (shape[0], num_groups)
-
-        scales = torch.rand(scale_shape) * 0.1
-        zero_points = torch.randint(-2**(num_bits-1), 2**(num_bits-1), scale_shape, dtype=torch.int8)
-
-        state_dict = {
-            "layer.weight": weights,
-            "layer.weight_scale": scales,
-            "layer.weight_zero_point": zero_points,
-        }
-
+
+        class SingleLayer(Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = Linear(shape[1], shape[0], bias=False)
+
+        model = SingleLayer()
+        apply_quantization_config(model, quant_config)
+
+        with torch.no_grad():
+            model.layer.weight.copy_(biased_weights)
+
         compressor = PackedQuantizationCompressor(config=quant_config)
         quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]}
-
+
         compressed_state_dict = compressor.compress(
-            state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
+            model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme
         )
 
         save_file(compressed_state_dict, tmp_path / "model.safetensors")
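
A note on the biased_weights rename: torch.randn(shape) + 2.0 deliberately shifts the weight distribution so its min and max are not symmetric about zero, which forces a nonzero zero-point and genuinely exercises the asymmetric path. A minimal sketch of standard min/max asymmetric quantization arithmetic, for intuition only (illustrative names, not the library's implementation):

import torch

def asymmetric_quantize(w: torch.Tensor, num_bits: int):
    # Standard asymmetric scheme: map [w_min, w_max] onto [qmin, qmax].
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    w_min, w_max = w.min(), w.max()
    scale = (w_max - w_min) / (qmax - qmin)
    # With a biased input (randn + 2.0), w_min != -w_max, so the
    # zero-point lands away from 0.
    zero_point = torch.round(qmin - w_min / scale).clamp(qmin, qmax)
    q = torch.round(w / scale + zero_point).clamp(qmin, qmax)
    return q, scale, zero_point

group = torch.randn(128) + 2.0  # mirrors biased_weights in the test
q, scale, zp = asymmetric_quantize(group, num_bits=4)
dequant = (q - zp) * scale
print(f"zero_point={zp.item():.0f}, max abs error={(group - dequant).abs().max():.4f}")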
