diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index e2ce3d24b..07da50c7f 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -134,8 +134,6 @@ def compress_weight(
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor
-        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -143,7 +141,7 @@
             packed_zp = pack_to_int32(
                 zero_point, quantization_args.num_bits, packed_dim=0
             )
-            compressed_dict["weight_zero_point"] = packed_zp
+            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
         return compressed_dict
 
     def decompress_weight(
@@ -166,16 +164,13 @@
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on
-        # decompression
        if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            raise ValueError(
-                "Decompression of packed zero points is currently not supported"
-            )
-            assert zero_point is not None
+            assert (
+                zero_point is not None
+            ), "Asymmetric quantization requires zero-point values"
             original_zp_shape = (original_shape[0], scale.shape[-1])
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
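For reviewers, a minimal sketch of the round trip this change enables (not part of the patch): int8 zero points are packed into int32 words along dim 0 at compression time and unpacked on load. The sketch assumes pack_to_int32 and unpack_from_int32 are importable from the patched module, matching how the tests below use them.

    import torch
    from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
        pack_to_int32,
        unpack_from_int32,
    )

    # 4-bit zero points for a (512, 1024) weight quantized with group_size=128,
    # i.e. 8 groups per row; dim 0 packs 8 four-bit values into each int32 word
    zp = torch.randint(-8, 8, (512, 8), dtype=torch.int8)
    packed = pack_to_int32(zp, 4, packed_dim=0)  # shape (64, 8), dtype int32
    restored = unpack_from_int32(packed, 4, zp.shape, packed_dim=0)
    assert torch.equal(zp, restored)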
+ """ + shape = (512, 1024) + + if strategy == QuantizationStrategy.CHANNEL: + expected_zp_shape = (shape[0], 1) + elif strategy == QuantizationStrategy.GROUP: + num_groups = shape[1] // group_size + expected_zp_shape = (shape[0], max(num_groups, 1)) + + dense_state_dict = { + "dummy.weight": torch.randn(shape), + "dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32), + "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(torch.int8), + } + + quant_config = get_dummy_quant_config( + num_bits=4, + strategy=strategy.value, + symmetric=False, + group_size=group_size + ) + + compressor = PackedQuantizationCompressor(config=quant_config) + quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]} + compressed_state_dict = compressor.compress( + dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme + ) + + assert "dummy.weight_zero_point" in compressed_state_dict + assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32 + + save_file(compressed_state_dict, tmp_path / "model.safetensors") + + reconstructed_dense_gen = compressor.decompress( + tmp_path, names_to_scheme=quantized_modules_to_scheme + ) + reconstructed_dense = {} + for name, value in reconstructed_dense_gen: + reconstructed_dense[name] = value + + assert "dummy" in reconstructed_dense + assert "weight" in reconstructed_dense["dummy"] + + assert reconstructed_dense["dummy"]["weight"].shape == shape + + shutil.rmtree(tmp_path) + + +@pytest.mark.parametrize( + "num_bits,strategy", + [ + (4, QuantizationStrategy.GROUP), + (4, QuantizationStrategy.CHANNEL), + (8, QuantizationStrategy.GROUP), + (8, QuantizationStrategy.CHANNEL), + ], +) +def test_zero_point_pack_unpack_consistency(num_bits, strategy): + """ + Test that packing and unpacking zero-points preserves values correctly. + """ + if strategy == QuantizationStrategy.GROUP: + shape = (512, 8) + group_size = 128 + else: + shape = (512, 1) + group_size = None + + max_val = (1 << (num_bits - 1)) - 1 + min_val = -(1 << (num_bits - 1)) + original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8) + + packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0) + + unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0) + + assert torch.equal(original_zp, unpacked_zp) + assert unpacked_zp.dtype == torch.int8 diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py new file mode 100644 index 000000000..62ece296f --- /dev/null +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end tests for asymmetric quantization with zero-point decompression. 
+""" + +import shutil +import tempfile +from pathlib import Path + +import pytest +import torch +from compressed_tensors import PackedQuantizationCompressor +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, + QuantizationStrategy, + apply_quantization_config, +) +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization.lifecycle.forward import fake_quantize +from safetensors.torch import save_file +from compressed_tensors.compressors.model_compressors.model_compressor import ( + ModelCompressor, +) +from torch.nn import Linear, Module, Sequential + + +class SimpleModel(Module): + """Simple model for testing""" + def __init__(self, input_dim=512, hidden_dim=256, output_dim=128): + super().__init__() + self.layer1 = Linear(input_dim, hidden_dim, bias=False) + self.layer2 = Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x): + x = self.layer1(x) + x = torch.relu(x) + x = self.layer2(x) + return x + + +def create_asymmetric_quant_config( + num_bits=4, + strategy=QuantizationStrategy.GROUP, + group_size=128 +) -> QuantizationConfig: + """Create an asymmetric quantization config""" + config_groups = { + "group_1": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=num_bits, + strategy=strategy.value, + group_size=group_size if strategy == QuantizationStrategy.GROUP else None, + symmetric=False, + ), + ), + } + return QuantizationConfig(config_groups=config_groups) + + +@pytest.mark.parametrize( + "strategy,group_size", + [ + (QuantizationStrategy.GROUP, 128), + (QuantizationStrategy.CHANNEL, None), + ], +) +def test_end_to_end_asymmetric_quantization( + strategy, + group_size, + mock_per_group_calibration, + mock_per_channel_calibration, +): + """ + Test end-to-end workflow: quantize -> compress -> save -> load -> decompress -> use + """ + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + model = SimpleModel() + original_weights = { + "layer1": model.layer1.weight.detach().clone(), + "layer2": model.layer2.weight.detach().clone(), + } + + quant_config = create_asymmetric_quant_config( + num_bits=4, + strategy=strategy, + group_size=group_size + ) + # Set pack-quantized format for ModelCompressor usage + quant_config.format = CompressionFormat.pack_quantized.value + apply_quantization_config(model, quant_config) + + if strategy == QuantizationStrategy.GROUP: + mock_per_group_calibration(model.layer1, "weight", model.layer1.weight, group_size) + mock_per_group_calibration(model.layer2, "weight", model.layer2.weight, group_size) + else: + mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight) + mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight) + + + + compressor = PackedQuantizationCompressor(config=quant_config) + quantized_modules_to_scheme = { + "layer1": quant_config.config_groups["group_1"], + "layer2": quant_config.config_groups["group_1"], + } + + state_dict = model.state_dict() + compressed_state_dict = compressor.compress( + state_dict, names_to_scheme=quantized_modules_to_scheme + ) + + assert "layer1.weight_zero_point" in compressed_state_dict + assert "layer2.weight_zero_point" in compressed_state_dict + assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32 + assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32 + + new_model = SimpleModel() + apply_quantization_config(new_model, quant_config) + + for module_name in ["layer1", 
"layer2"]: + module = getattr(new_model, module_name) + prefix = f"{module_name}." + for key, value in compressed_state_dict.items(): + if key.startswith(prefix): + param_name = key[len(prefix):] + if hasattr(module, param_name): + getattr(module, param_name).data = value.clone() + else: + module.register_parameter( + param_name, torch.nn.Parameter(value.clone(), requires_grad=False) + ) + + mc = ModelCompressor(quantization_config=quant_config) + mc.decompress_model(new_model) + + assert new_model.layer1.weight.shape == original_weights["layer1"].shape + assert new_model.layer2.weight.shape == original_weights["layer2"].shape + assert new_model.layer1.weight.dtype.is_floating_point + assert new_model.layer2.weight.dtype.is_floating_point + assert not torch.isnan(new_model.layer1.weight).any() + assert not torch.isnan(new_model.layer2.weight).any() + assert not torch.isinf(new_model.layer1.weight).any() + assert not torch.isinf(new_model.layer2.weight).any() + + +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration): + """ + Test that asymmetric quantization with zero-point preserves accuracy better + than symmetric quantization for biased weight distributions. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + shape = (256, 512) + biased_weights = torch.randn(shape) + 2.0 + + quant_config = create_asymmetric_quant_config( + num_bits=num_bits, + strategy=QuantizationStrategy.GROUP, + group_size=128, + ) + quant_config.format = CompressionFormat.pack_quantized.value + + class SingleLayer(Module): + def __init__(self): + super().__init__() + self.layer = Linear(shape[1], shape[0], bias=False) + + model = SingleLayer() + apply_quantization_config(model, quant_config) + + with torch.no_grad(): + model.layer.weight.copy_(biased_weights) + mock_per_group_calibration(model.layer, "weight", model.layer.weight, 128) + + compressor = PackedQuantizationCompressor(config=quant_config) + quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]} + + compressed_state_dict = compressor.compress( + model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme + ) + + new_model = SingleLayer() + apply_quantization_config(new_model, quant_config) + + module = new_model.layer + for key, value in compressed_state_dict.items(): + if key.startswith("layer."): + param_name = key[len("layer."):] + if hasattr(module, param_name): + getattr(module, param_name).data = value.clone() + else: + module.register_parameter( + param_name, torch.nn.Parameter(value.clone(), requires_grad=False) + ) + + mc = ModelCompressor(quantization_config=quant_config) + mc.decompress_model(new_model) + + decompressed_weights = new_model.layer.weight + assert decompressed_weights.shape == shape + assert not torch.isnan(decompressed_weights).any() + assert not torch.isinf(decompressed_weights).any() + threshold = torch.std(torch.rand(shape) - torch.rand(shape)) + assert torch.std(biased_weights - decompressed_weights) < threshold \ No newline at end of file