|
| 1 | +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, |
| 10 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +End-to-end tests for asymmetric quantization with zero-point decompression. |
| 17 | +""" |
| 18 | + |
| 19 | +import shutil |
| 20 | +import tempfile |
| 21 | +from pathlib import Path |
| 22 | + |
| 23 | +import pytest |
| 24 | +import torch |
| 25 | +from compressed_tensors import PackedQuantizationCompressor |
| 26 | +from compressed_tensors.quantization import ( |
| 27 | + QuantizationArgs, |
| 28 | + QuantizationConfig, |
| 29 | + QuantizationScheme, |
| 30 | + QuantizationStrategy, |
| 31 | + apply_quantization_config, |
| 32 | +) |
| 33 | +from compressed_tensors.quantization.lifecycle.forward import fake_quantize |
| 34 | +from safetensors.torch import save_file |
| 35 | +from torch.nn import Linear, Module, Sequential |
| 36 | + |
| 37 | + |
| 38 | +class SimpleModel(Module): |
| 39 | + """Simple model for testing""" |
| 40 | + def __init__(self, input_dim=512, hidden_dim=256, output_dim=128): |
| 41 | + super().__init__() |
| 42 | + self.layer1 = Linear(input_dim, hidden_dim, bias=False) |
| 43 | + self.layer2 = Linear(hidden_dim, output_dim, bias=False) |
| 44 | + |
| 45 | + def forward(self, x): |
| 46 | + x = self.layer1(x) |
| 47 | + x = torch.relu(x) |
| 48 | + x = self.layer2(x) |
| 49 | + return x |
| 50 | + |
| 51 | + |
| 52 | +def create_asymmetric_quant_config( |
| 53 | + num_bits=4, |
| 54 | + strategy=QuantizationStrategy.GROUP, |
| 55 | + group_size=128 |
| 56 | +) -> QuantizationConfig: |
| 57 | + """Create an asymmetric quantization config""" |
| 58 | + config_groups = { |
| 59 | + "group_1": QuantizationScheme( |
| 60 | + targets=["Linear"], |
| 61 | + weights=QuantizationArgs( |
| 62 | + num_bits=num_bits, |
| 63 | + strategy=strategy.value, |
| 64 | + group_size=group_size if strategy == QuantizationStrategy.GROUP else None, |
| 65 | + symmetric=False, |
| 66 | + ), |
| 67 | + ), |
| 68 | + } |
| 69 | + return QuantizationConfig(config_groups=config_groups) |
| 70 | + |
| 71 | + |
| 72 | +@pytest.mark.parametrize( |
| 73 | + "strategy,group_size", |
| 74 | + [ |
| 75 | + (QuantizationStrategy.GROUP, 128), |
| 76 | + (QuantizationStrategy.CHANNEL, None), |
| 77 | + ], |
| 78 | +) |
| 79 | +def test_end_to_end_asymmetric_quantization(strategy, group_size): |
| 80 | + """ |
| 81 | + Test end-to-end workflow: quantize -> compress -> save -> load -> decompress -> use |
| 82 | + """ |
| 83 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 84 | + tmp_path = Path(tmp_dir) |
| 85 | + |
| 86 | + model = SimpleModel() |
| 87 | + original_weights = { |
| 88 | + "layer1": model.layer1.weight.clone(), |
| 89 | + "layer2": model.layer2.weight.clone(), |
| 90 | + } |
| 91 | + |
| 92 | + quant_config = create_asymmetric_quant_config( |
| 93 | + num_bits=4, |
| 94 | + strategy=strategy, |
| 95 | + group_size=group_size |
| 96 | + ) |
| 97 | + apply_quantization_config(model, quant_config) |
| 98 | + |
| 99 | + for name, module in model.named_modules(): |
| 100 | + if isinstance(module, Linear): |
| 101 | + weight = module.weight |
| 102 | + if strategy == QuantizationStrategy.CHANNEL: |
| 103 | + scale_shape = (weight.shape[0], 1) |
| 104 | + else: |
| 105 | + scale_shape = (weight.shape[0], weight.shape[1] // group_size) |
| 106 | + |
| 107 | + module.weight_scale = torch.nn.Parameter( |
| 108 | + torch.rand(scale_shape) * 0.1, |
| 109 | + requires_grad=False |
| 110 | + ) |
| 111 | + module.weight_zero_point = torch.nn.Parameter( |
| 112 | + torch.randint(-8, 8, scale_shape, dtype=torch.int8), |
| 113 | + requires_grad=False |
| 114 | + ) |
| 115 | + |
| 116 | + compressor = PackedQuantizationCompressor(config=quant_config) |
| 117 | + quantized_modules_to_scheme = { |
| 118 | + "layer1": quant_config.config_groups["group_1"], |
| 119 | + "layer2": quant_config.config_groups["group_1"], |
| 120 | + } |
| 121 | + |
| 122 | + state_dict = model.state_dict() |
| 123 | + compressed_state_dict = compressor.compress( |
| 124 | + state_dict, names_to_scheme=quantized_modules_to_scheme |
| 125 | + ) |
| 126 | + |
| 127 | + assert "layer1.weight_zero_point" in compressed_state_dict |
| 128 | + assert "layer2.weight_zero_point" in compressed_state_dict |
| 129 | + assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32 |
| 130 | + assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32 |
| 131 | + |
| 132 | + save_file(compressed_state_dict, tmp_path / "model.safetensors") |
| 133 | + |
| 134 | + reconstructed_gen = compressor.decompress( |
| 135 | + tmp_path, names_to_scheme=quantized_modules_to_scheme |
| 136 | + ) |
| 137 | + |
| 138 | + reconstructed_weights = {} |
| 139 | + for module_name, module_data in reconstructed_gen: |
| 140 | + reconstructed_weights[module_name] = module_data |
| 141 | + |
| 142 | + assert "layer1" in reconstructed_weights |
| 143 | + assert "layer2" in reconstructed_weights |
| 144 | + assert "weight" in reconstructed_weights["layer1"] |
| 145 | + assert "weight" in reconstructed_weights["layer2"] |
| 146 | + |
| 147 | + assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape |
| 148 | + assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape |
| 149 | + |
| 150 | + new_model = SimpleModel() |
| 151 | + new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"] |
| 152 | + new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"] |
| 153 | + |
| 154 | + test_input = torch.randn(1, 512) |
| 155 | + with torch.no_grad(): |
| 156 | + output = new_model(test_input) |
| 157 | + |
| 158 | + assert output.shape == (1, 128) |
| 159 | + assert not torch.isnan(output).any() |
| 160 | + assert not torch.isinf(output).any() |
| 161 | + |
| 162 | + |
| 163 | +@pytest.mark.parametrize("num_bits", [4, 8]) |
| 164 | +def test_asymmetric_quantization_accuracy(num_bits): |
| 165 | + """ |
| 166 | + Test that asymmetric quantization with zero-point preserves accuracy better |
| 167 | + than symmetric quantization for biased weight distributions. |
| 168 | + """ |
| 169 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 170 | + tmp_path = Path(tmp_dir) |
| 171 | + |
| 172 | + shape = (256, 512) |
| 173 | + weights = torch.randn(shape) + 2.0 |
| 174 | + |
| 175 | + quant_config = create_asymmetric_quant_config( |
| 176 | + num_bits=num_bits, |
| 177 | + strategy=QuantizationStrategy.GROUP, |
| 178 | + group_size=128 |
| 179 | + ) |
| 180 | + |
| 181 | + group_size = 128 |
| 182 | + num_groups = shape[1] // group_size |
| 183 | + scale_shape = (shape[0], num_groups) |
| 184 | + |
| 185 | + scales = torch.rand(scale_shape) * 0.1 |
| 186 | + zero_points = torch.randint(-2**(num_bits-1), 2**(num_bits-1), scale_shape, dtype=torch.int8) |
| 187 | + |
| 188 | + state_dict = { |
| 189 | + "layer.weight": weights, |
| 190 | + "layer.weight_scale": scales, |
| 191 | + "layer.weight_zero_point": zero_points, |
| 192 | + } |
| 193 | + |
| 194 | + compressor = PackedQuantizationCompressor(config=quant_config) |
| 195 | + quantized_modules_to_scheme = {"layer": quant_config.config_groups["group_1"]} |
| 196 | + |
| 197 | + compressed_state_dict = compressor.compress( |
| 198 | + state_dict.copy(), names_to_scheme=quantized_modules_to_scheme |
| 199 | + ) |
| 200 | + |
| 201 | + save_file(compressed_state_dict, tmp_path / "model.safetensors") |
| 202 | + |
| 203 | + reconstructed_gen = compressor.decompress( |
| 204 | + tmp_path, names_to_scheme=quantized_modules_to_scheme |
| 205 | + ) |
| 206 | + |
| 207 | + reconstructed = {} |
| 208 | + for module_name, module_data in reconstructed_gen: |
| 209 | + reconstructed[module_name] = module_data |
| 210 | + |
| 211 | + assert "layer" in reconstructed |
| 212 | + assert "weight" in reconstructed["layer"] |
| 213 | + assert reconstructed["layer"]["weight"].shape == shape |
| 214 | + |
| 215 | + decompressed_weights = reconstructed["layer"]["weight"] |
| 216 | + assert not torch.isnan(decompressed_weights).any() |
| 217 | + assert not torch.isinf(decompressed_weights).any() |
| 218 | + |
| 219 | + assert decompressed_weights.abs().max() < 100 |
| 220 | + assert decompressed_weights.abs().max() > 0.01 |
| 221 | + |
| 222 | + |
| 223 | +if __name__ == "__main__": |
| 224 | + test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128) |
| 225 | + test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None) |
| 226 | + test_asymmetric_quantization_accuracy(4) |
| 227 | + test_asymmetric_quantization_accuracy(8) |
| 228 | + print("All tests passed!") |
0 commit comments