Commit f0ae89e

feat: add zero-point decompression support for asymmetric quantization
- Fix decompress_weight method in PackedQuantizationCompressor to support unpacking zero-points
- Add comprehensive tests for zero-point packing/unpacking with GROUP and CHANNEL strategies
- Add end-to-end integration tests for asymmetric quantization workflow
- Ensure packed tensors are contiguous for safetensors compatibility

Resolves issue referenced in vllm-project/llm-compressor#1704
1 parent 5718b29 commit f0ae89e

File tree

3 files changed (+325, -13 lines)
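For context: with asymmetric quantization the dequantizer recovers weights as w ≈ (q - zero_point) * scale, so the zero-points that compress_weight packs away must be unpacked again before dequantize can run, which is what this commit adds. A minimal sketch of that arithmetic with made-up numbers (illustrative only, not the library's API):

import torch

# Illustrative values: one scale/zero_point pair and three already-unpacked
# int4 weight codes. Asymmetric dequantization is w_hat = (q - zero_point) * scale.
scale = torch.tensor([0.05])
zero_point = torch.tensor([3], dtype=torch.int8)
q = torch.tensor([-2, 0, 5], dtype=torch.int8)

w_hat = (q.to(torch.float32) - zero_point.to(torch.float32)) * scale
print(w_hat)  # -> tensor([-0.2500, -0.1500,  0.1000])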

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 6 additions & 13 deletions
@@ -134,16 +134,14 @@ def compress_weight(
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor
-        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
             packed_zp = pack_to_int32(
                 zero_point, quantization_args.num_bits, packed_dim=0
             )
-            compressed_dict["weight_zero_point"] = packed_zp
+            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
         return compressed_dict
 
     def decompress_weight(
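The .contiguous() call added above matters because safetensors' save_file rejects tensors backed by non-contiguous storage, and packing along dim 0 can leave the packed tensor non-contiguous (for example when it is produced via a transpose). A small sketch of the property in question, using a transpose just to create a non-contiguous tensor:

import torch

# A transpose yields a view with non-contiguous storage, similar to what a
# pack-along-dim-0 implementation can produce.
packed_zp = torch.arange(16, dtype=torch.int32).reshape(4, 4).t()
print(packed_zp.is_contiguous())               # False -> save_file would reject it
print(packed_zp.contiguous().is_contiguous())  # True  -> safe to serialize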
@@ -166,20 +164,15 @@ def decompress_weight(
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on
-        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            raise ValueError(
-                "Decompression of packed zero points is currently not supported"
-            )
-            assert zero_point is not None
-            original_zp_shape = (original_shape[0], scale.shape[-1])
-            zero_point = unpack_from_int32(
-                zero_point, num_bits, original_zp_shape, packed_dim=0
-            )
+            if zero_point is not None:
+                original_zp_shape = (original_shape[0], scale.shape[-1])
+                zero_point = unpack_from_int32(
+                    zero_point, num_bits, original_zp_shape, packed_dim=0
+                )
 
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
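The unpacked zero-point shape is recovered from the scale: original_zp_shape = (original_shape[0], scale.shape[-1]) is (out_features, num_groups) under the GROUP strategy and (out_features, 1) under CHANNEL. A worked shape example, assuming the packed dimension shrinks by a factor of 32 // num_bits (the layer sizes below are hypothetical):

import math

# Hypothetical layer: 4-bit GROUP quantization of a (512, 1024) weight, group_size=128.
out_features, in_features = 512, 1024
num_bits, group_size = 4, 128

num_groups = in_features // group_size                   # 8
zp_shape = (out_features, num_groups)                    # (512, 8), matches scale.shape
pack_factor = 32 // num_bits                             # 8 values per int32 word
packed_zp_rows = math.ceil(out_features / pack_factor)   # 64, since packing runs along dim 0
print(zp_shape, (packed_zp_rows, num_groups))            # (512, 8) (64, 8)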
Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+End-to-end tests for asymmetric quantization with zero-point decompression.
+"""
+
+import shutil
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+from compressed_tensors import PackedQuantizationCompressor
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+    QuantizationStrategy,
+    apply_quantization_config,
+)
+from compressed_tensors.quantization.lifecycle.forward import fake_quantize
+from safetensors.torch import save_file
+from torch.nn import Linear, Module, Sequential
+
+
+class SimpleModel(Module):
+    """Simple model for testing"""
+    def __init__(self, input_dim=512, hidden_dim=256, output_dim=128):
+        super().__init__()
+        self.layer1 = Linear(input_dim, hidden_dim, bias=False)
+        self.layer2 = Linear(hidden_dim, output_dim, bias=False)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = torch.relu(x)
+        x = self.layer2(x)
+        return x
+
+
+def create_asymmetric_quant_config(
+    num_bits=4,
+    strategy=QuantizationStrategy.GROUP,
+    group_size=128
+) -> QuantizationConfig:
+    """Create an asymmetric quantization config"""
+    config_groups = {
+        "group_1": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(
+                num_bits=num_bits,
+                strategy=strategy.value,
+                group_size=group_size if strategy == QuantizationStrategy.GROUP else None,
+                symmetric=False,
+            ),
+        ),
+    }
+    return QuantizationConfig(config_groups=config_groups)
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size",
+    [
+        (QuantizationStrategy.GROUP, 128),
+        (QuantizationStrategy.CHANNEL, None),
+    ],
+)
+def test_end_to_end_asymmetric_quantization(strategy, group_size):
+    """
+    Test end-to-end workflow: quantize -> compress -> save -> load -> decompress -> use
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        tmp_path = Path(tmp_dir)
+
+        model = SimpleModel()
+        original_weights = {
+            "layer1": model.layer1.weight.clone(),
+            "layer2": model.layer2.weight.clone(),
+        }
+
+        quant_config = create_asymmetric_quant_config(
+            num_bits=4,
+            strategy=strategy,
+            group_size=group_size
+        )
+        apply_quantization_config(model, quant_config)
+
+        for name, module in model.named_modules():
+            if isinstance(module, Linear):
+                weight = module.weight
+                if strategy == QuantizationStrategy.CHANNEL:
+                    scale_shape = (weight.shape[0], 1)
+                else:
+                    scale_shape = (weight.shape[0], weight.shape[1] // group_size)
+
+                module.weight_scale = torch.nn.Parameter(
+                    torch.rand(scale_shape) * 0.1,
+                    requires_grad=False
+                )
+                module.weight_zero_point = torch.nn.Parameter(
+                    torch.randint(-8, 8, scale_shape, dtype=torch.int8),
+                    requires_grad=False
+                )
+
+        compressor = PackedQuantizationCompressor(config=quant_config)
+        quantized_modules_to_scheme = {
+            "layer1": quant_config.config_groups["group_1"],
+            "layer2": quant_config.config_groups["group_1"],
+        }
+
+        state_dict = model.state_dict()
+        compressed_state_dict = compressor.compress(
+            state_dict, names_to_scheme=quantized_modules_to_scheme
+        )
+
+        assert "layer1.weight_zero_point" in compressed_state_dict
+        assert "layer2.weight_zero_point" in compressed_state_dict
+        assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32
+        assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32
+
+        save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
+        reconstructed_gen = compressor.decompress(
+            tmp_path, names_to_scheme=quantized_modules_to_scheme
+        )
+
+        reconstructed_weights = {}
+        for module_name, module_data in reconstructed_gen:
+            reconstructed_weights[module_name] = module_data
+
+        assert "layer1" in reconstructed_weights
+        assert "layer2" in reconstructed_weights
+        assert "weight" in reconstructed_weights["layer1"]
+        assert "weight" in reconstructed_weights["layer2"]
+
+        assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape
+        assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape
+
+        new_model = SimpleModel()
+        new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"]
+        new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"]
+
+        test_input = torch.randn(1, 512)
+        with torch.no_grad():
+            output = new_model(test_input)
+
+        assert output.shape == (1, 128)
+        assert not torch.isnan(output).any()
+        assert not torch.isinf(output).any()
+
+
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_asymmetric_quantization_accuracy(num_bits):
+    """
+    Test that asymmetric quantization with zero-point preserves accuracy better
+    than symmetric quantization for biased weight distributions.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        tmp_path = Path(tmp_dir)
+
+        shape = (256, 512)
+        weights = torch.randn(shape) + 2.0
+
+        quant_config = create_asymmetric_quant_config(
+            num_bits=num_bits,
+            strategy=QuantizationStrategy.GROUP,
+            group_size=128
+        )
+
+        group_size = 128
+        num_groups = shape[1] // group_size
+        scale_shape = (shape[0], num_groups)
+
+        scales = torch.rand(scale_shape) * 0.1
+        zero_points = torch.randint(-2**(num_bits-1), 2**(num_bits-1), scale_shape, dtype=torch.int8)
+
+        state_dict = {
+            "layer.weight": weights,
+            "layer.weight_scale": scales,
+            "layer.weight_zero_point": zero_points,
+        }
+
+        compressor = PackedQuantizationCompressor(config=quant_config)
+        quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]}
+
+        compressed_state_dict = compressor.compress(
+            state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
+        )
+
+        save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
+        reconstructed_gen = compressor.decompress(
+            tmp_path, names_to_scheme=quantized_modules_to_scheme
+        )
+
+        reconstructed = {}
+        for module_name, module_data in reconstructed_gen:
+            reconstructed[module_name] = module_data
+
+        assert "layer" in reconstructed
+        assert "weight" in reconstructed["layer"]
+        assert reconstructed["layer"]["weight"].shape == shape
+
+        decompressed_weights = reconstructed["layer"]["weight"]
+        assert not torch.isnan(decompressed_weights).any()
+        assert not torch.isinf(decompressed_weights).any()
+
+        assert decompressed_weights.abs().max() < 100
+        assert decompressed_weights.abs().max() > 0.01
+
+
+if __name__ == "__main__":
+    test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128)
+    test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None)
+    test_asymmetric_quantization_accuracy(4)
+    test_asymmetric_quantization_accuracy(8)
+    print("All tests passed!")
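The accuracy test above uses weights biased away from zero (torch.randn(shape) + 2.0), which is the case where a zero-point pays off: a symmetric grid centered on zero spends half its codes on values that never occur. A rough, self-contained illustration of that effect on a single 4-bit group (the helper below is hypothetical and not part of the test suite):

import torch

def mean_quant_error(w, symmetric, qmin=-8, qmax=7):
    """Fake-quantize one group of weights to int4 and return the mean absolute error."""
    if symmetric:
        scale = w.abs().max() / qmax
        zero_point = torch.tensor(0.0)
    else:
        scale = (w.max() - w.min()) / (qmax - qmin)
        zero_point = torch.round(qmin - w.min() / scale)
    q = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax)
    return ((q - zero_point) * scale - w).abs().mean()

w = torch.randn(128) + 2.0  # biased weights, as in the accuracy test
print(mean_quant_error(w, symmetric=True), mean_quant_error(w, symmetric=False))
# The asymmetric error is typically noticeably smaller for this distribution.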

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 91 additions & 0 deletions
@@ -473,3 +473,94 @@ def test_unpack_from_int32(num_bits, values, expected_tensor):
     unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
     assert torch.equal(unpacked_tensor, unpacked_tensor)
     assert unpacked_tensor.dtype == unpacked_tensor.dtype
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size",
+    [
+        (QuantizationStrategy.GROUP, 128),
+        (QuantizationStrategy.CHANNEL, None),
+    ],
+)
+def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path):
+    """
+    Test that zero-point packing and unpacking works correctly for asymmetric quantization
+    with GROUP and CHANNEL strategies.
+    """
+    shape = (512, 1024)
+
+    if strategy == QuantizationStrategy.CHANNEL:
+        expected_zp_shape = (shape[0], 1)
+    elif strategy == QuantizationStrategy.GROUP:
+        num_groups = shape[1] // group_size
+        expected_zp_shape = (shape[0], max(num_groups, 1))
+
+    dense_state_dict = {
+        "dummy.weight": torch.randn(shape),
+        "dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32),
+        "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(torch.int8),
+    }
+
+    quant_config = get_dummy_quant_config(
+        num_bits=4,
+        strategy=strategy.value,
+        symmetric=False,
+        group_size=group_size
+    )
+
+    compressor = PackedQuantizationCompressor(config=quant_config)
+    quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
+    compressed_state_dict = compressor.compress(
+        dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
+    )
+
+    assert "dummy.weight_zero_point" in compressed_state_dict
+    assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32
+
+    save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
+    reconstructed_dense_gen = compressor.decompress(
+        tmp_path, names_to_scheme=quantized_modules_to_scheme
+    )
+    reconstructed_dense = {}
+    for name, value in reconstructed_dense_gen:
+        reconstructed_dense[name] = value
+
+    assert "dummy" in reconstructed_dense
+    assert "weight" in reconstructed_dense["dummy"]
+
+    assert reconstructed_dense["dummy"]["weight"].shape == shape
+
+    shutil.rmtree(tmp_path)
+
+
+@pytest.mark.parametrize(
+    "num_bits,strategy",
+    [
+        (4, QuantizationStrategy.GROUP),
+        (4, QuantizationStrategy.CHANNEL),
+        (8, QuantizationStrategy.GROUP),
+        (8, QuantizationStrategy.CHANNEL),
+    ],
+)
+def test_zero_point_pack_unpack_consistency(num_bits, strategy):
+    """
+    Test that packing and unpacking zero-points preserves values correctly.
+    """
+    if strategy == QuantizationStrategy.GROUP:
+        shape = (512, 8)
+        group_size = 128
+    else:
+        shape = (512, 1)
+        group_size = None
+
+    max_val = (1 << (num_bits - 1)) - 1
+    min_val = -(1 << (num_bits - 1))
+    original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8)
+
+    packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0)
+
+    unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0)
+
+    assert torch.equal(original_zp, unpacked_zp)
+    assert unpacked_zp.dtype == torch.int8
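test_zero_point_pack_unpack_consistency round-trips zero-points through pack_to_int32 and unpack_from_int32. As a mental model of what such a round trip involves, here is a self-contained sketch that packs signed int4 values into int32 words and back; the bit layout is illustrative and not necessarily the library's exact ordering:

import torch

NUM_BITS = 4
PACK_FACTOR = 32 // NUM_BITS   # 8 int4 values per int32 word
OFFSET = 1 << (NUM_BITS - 1)   # shift [-8, 7] into [0, 15]
MASK = (1 << NUM_BITS) - 1

def pack_values(values):
    """Pack a flat list of signed int4 values into int32 words (illustrative layout)."""
    words = []
    for i in range(0, len(values), PACK_FACTOR):
        word = 0
        for j, v in enumerate(values[i:i + PACK_FACTOR]):
            word |= ((int(v) + OFFSET) & MASK) << (j * NUM_BITS)
        # Reinterpret the unsigned 32-bit pattern as a signed int32
        words.append(word - (1 << 32) if word >= (1 << 31) else word)
    return torch.tensor(words, dtype=torch.int32)

def unpack_values(words, count):
    """Invert pack_values, recovering `count` signed int4 values."""
    values = []
    for word in words.tolist():
        unsigned = word & 0xFFFFFFFF
        for j in range(PACK_FACTOR):
            if len(values) == count:
                break
            values.append(((unsigned >> (j * NUM_BITS)) & MASK) - OFFSET)
    return torch.tensor(values, dtype=torch.int8)

zp = torch.randint(-8, 8, (16,), dtype=torch.int8)
assert torch.equal(unpack_values(pack_values(zp.tolist()), 16), zp)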
