
Commit 126fc89

tests: use in-memory decompress_model; calibrate via fixtures; std-dev similarity; cleanup temp usage
1 parent c0cbb70 commit 126fc89

2 files changed: +63 -61 lines changed
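Both tests in the second file now skip the on-disk safetensors round trip: the compressed tensors are loaded onto a freshly quantization-initialized module and ModelCompressor.decompress_model rewrites the weights in place. A minimal sketch of that flow, collapsed into a hypothetical decompress_in_memory helper that is not part of the test suite (the ModelCompressor import mirrors the diff below; the apply_quantization_config import path is assumed):

import torch
from compressed_tensors.compressors.model_compressors.model_compressor import (
    ModelCompressor,
)
from compressed_tensors.quantization import apply_quantization_config


def decompress_in_memory(model, quant_config, compressed_state_dict):
    """Hypothetical helper: rebuild dense weights without touching the filesystem.

    quant_config is expected to already carry the pack-quantized format,
    as the tests below set via CompressionFormat.pack_quantized.value.
    """
    # recreate the quantized parameter layout (packed weights, scales, zero points)
    apply_quantization_config(model, quant_config)

    # copy each compressed tensor onto the submodule that owns it
    for key, value in compressed_state_dict.items():
        module_name, param_name = key.rsplit(".", 1)
        module = model.get_submodule(module_name)
        if hasattr(module, param_name):
            getattr(module, param_name).data = value.clone()
        else:
            module.register_parameter(
                param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
            )

    # unpack and dequantize in place; no temp directory or safetensors file needed
    ModelCompressor(quantization_config=quant_config).decompress_model(model)
    return model

The tests below inline this per-layer loading loop instead of using a helper.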

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 7 additions & 5 deletions
@@ -15,6 +15,7 @@
 
 import math
 import shutil
+import tempfile
 from collections import OrderedDict
 
 import pytest
@@ -170,12 +171,13 @@ def test_reload_match(tmp_path, num_bits):
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
 
-    reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
     reconstructed_dense = {}
-    for name, value in reconstructed_dense_gen:
-        reconstructed_dense[name] = value
+    with tempfile.TemporaryDirectory() as _tmp:
+        reconstructed_dense_gen = compressor.decompress(
+            tmp_path, names_to_scheme=quantized_modules_to_scheme
+        )
+        for name, value in reconstructed_dense_gen:
+            reconstructed_dense[name] = value
 
     fake_quant_dummy = fake_quantize(
         dense_state_dict["dummy.weight"],

tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py

Lines changed: 56 additions & 56 deletions
@@ -30,8 +30,12 @@
     QuantizationStrategy,
     apply_quantization_config,
 )
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize
 from safetensors.torch import save_file
+from compressed_tensors.compressors.model_compressors.model_compressor import (
+    ModelCompressor,
+)
 from torch.nn import Linear, Module, Sequential
 
 
@@ -90,15 +94,17 @@ def test_end_to_end_asymmetric_quantization(
 
     model = SimpleModel()
     original_weights = {
-        "layer1": model.layer1.weight.clone(),
-        "layer2": model.layer2.weight.clone(),
+        "layer1": model.layer1.weight.detach().clone(),
+        "layer2": model.layer2.weight.detach().clone(),
     }
 
     quant_config = create_asymmetric_quant_config(
         num_bits=4,
         strategy=strategy,
         group_size=group_size
     )
+    # Set pack-quantized format for ModelCompressor usage
+    quant_config.format = CompressionFormat.pack_quantized.value
     apply_quantization_config(model, quant_config)
 
     if strategy == QuantizationStrategy.GROUP:
@@ -126,35 +132,33 @@ def test_end_to_end_asymmetric_quantization(
     assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32
     assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32
 
-    save_file(compressed_state_dict, tmp_path / "model.safetensors")
-
-    reconstructed_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
-
-    reconstructed_weights = {}
-    for module_name, module_data in reconstructed_gen:
-        reconstructed_weights[module_name] = module_data
-
-    assert "layer1" in reconstructed_weights
-    assert "layer2" in reconstructed_weights
-    assert "weight" in reconstructed_weights["layer1"]
-    assert "weight" in reconstructed_weights["layer2"]
-
-    assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape
-    assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape
-
     new_model = SimpleModel()
-    new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"]
-    new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"]
-
-    test_input = torch.randn(1, 512)
-    with torch.no_grad():
-        output = new_model(test_input)
-
-    assert output.shape == (1, 128)
-    assert not torch.isnan(output).any()
-    assert not torch.isinf(output).any()
+    apply_quantization_config(new_model, quant_config)
+
+    for module_name in ["layer1", "layer2"]:
+        module = getattr(new_model, module_name)
+        prefix = f"{module_name}."
+        for key, value in compressed_state_dict.items():
+            if key.startswith(prefix):
+                param_name = key[len(prefix):]
+                if hasattr(module, param_name):
+                    getattr(module, param_name).data = value.clone()
+                else:
+                    module.register_parameter(
+                        param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                    )
+
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.decompress_model(new_model)
+
+    assert new_model.layer1.weight.shape == original_weights["layer1"].shape
+    assert new_model.layer2.weight.shape == original_weights["layer2"].shape
+    assert new_model.layer1.weight.dtype.is_floating_point
+    assert new_model.layer2.weight.dtype.is_floating_point
+    assert not torch.isnan(new_model.layer1.weight).any()
+    assert not torch.isnan(new_model.layer2.weight).any()
+    assert not torch.isinf(new_model.layer1.weight).any()
+    assert not torch.isinf(new_model.layer2.weight).any()
 
 
 @pytest.mark.parametrize("num_bits", [4, 8])
@@ -174,6 +178,7 @@ def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration):
         strategy=QuantizationStrategy.GROUP,
         group_size=128,
     )
+    quant_config.format = CompressionFormat.pack_quantized.value
 
     class SingleLayer(Module):
         def __init__(self):
@@ -194,31 +199,26 @@ def __init__(self):
         model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme
     )
 
-    save_file(compressed_state_dict, tmp_path / "model.safetensors")
-
-    reconstructed_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
-
-    reconstructed = {}
-    for module_name, module_data in reconstructed_gen:
-        reconstructed[module_name] = module_data
-
-    assert "layer" in reconstructed
-    assert "weight" in reconstructed["layer"]
-    assert reconstructed["layer"]["weight"].shape == shape
-
-    decompressed_weights = reconstructed["layer"]["weight"]
+    new_model = SingleLayer()
+    apply_quantization_config(new_model, quant_config)
+
+    module = new_model.layer
+    for key, value in compressed_state_dict.items():
+        if key.startswith("layer."):
+            param_name = key[len("layer."):]
+            if hasattr(module, param_name):
+                getattr(module, param_name).data = value.clone()
+            else:
+                module.register_parameter(
+                    param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                )
+
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.decompress_model(new_model)
+
+    decompressed_weights = new_model.layer.weight
+    assert decompressed_weights.shape == shape
     assert not torch.isnan(decompressed_weights).any()
     assert not torch.isinf(decompressed_weights).any()
-
-    assert decompressed_weights.abs().max() < 100
-    assert decompressed_weights.abs().max() > 0.01
-
-
-if __name__ == "__main__":
-    test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128)
-    test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None)
-    test_asymmetric_quantization_accuracy(4)
-    test_asymmetric_quantization_accuracy(8)
-    print("All tests passed!")
+    threshold = torch.std(torch.rand(shape) - torch.rand(shape))
+    assert torch.std(biased_weights - decompressed_weights) < threshold
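The final assertion replaces the old fixed magnitude bounds with a statistical similarity check: the standard deviation of the error between the decompressed weights and the reference tensor (biased_weights, built earlier in the test and not shown in this hunk) must be smaller than the standard deviation of the difference between two unrelated uniform tensors of the same shape, i.e. the reconstruction must be closer to the reference than chance. A self-contained sketch of the idea with illustrative values:

import torch

torch.manual_seed(0)
shape = (512, 1024)

# baseline: typical spread between two unrelated U(0, 1) tensors, std ~ sqrt(1/6) ~ 0.41
threshold = torch.std(torch.rand(shape) - torch.rand(shape))

# a reconstruction that tracks its reference has a much smaller error spread
reference = torch.rand(shape)
decompressed = reference + 0.01 * torch.randn(shape)  # stand-in for quantization error
assert torch.std(reference - decompressed) < threshold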
