     QuantizationStrategy,
     apply_quantization_config,
 )
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize
 from safetensors.torch import save_file
+from compressed_tensors.compressors.model_compressors.model_compressor import (
+    ModelCompressor,
+)
 from torch.nn import Linear, Module, Sequential

@@ -90,15 +94,17 @@ def test_end_to_end_asymmetric_quantization(

     model = SimpleModel()
     original_weights = {
-        "layer1": model.layer1.weight.clone(),
-        "layer2": model.layer2.weight.clone(),
+        "layer1": model.layer1.weight.detach().clone(),
+        "layer2": model.layer2.weight.detach().clone(),
     }

     quant_config = create_asymmetric_quant_config(
         num_bits=4,
         strategy=strategy,
         group_size=group_size
     )
+    # Set pack-quantized format for ModelCompressor usage
+    quant_config.format = CompressionFormat.pack_quantized.value
     apply_quantization_config(model, quant_config)

     if strategy == QuantizationStrategy.GROUP:
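Why the format hint matters: `ModelCompressor` picks its weight compressor from the config's `format` field, and `pack_quantized` stores the 4-bit values packed into `int32` words, consistent with the `torch.int32` zero-point assertions further down. The sketch below illustrates the packing idea only; `pack_int4` is our name, and the library's actual layout may differ.

```python
import torch

def pack_int4(vals: torch.Tensor) -> torch.Tensor:
    """Pack eight 4-bit values per int32 word (illustrative layout)."""
    assert vals.numel() % 8 == 0
    nibbles = (vals.to(torch.int32) & 0xF).view(-1, 8)
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    # nibbles occupy disjoint bit ranges, so summing is equivalent to OR-ing
    return (nibbles << shifts).sum(dim=1, dtype=torch.int32)

packed = pack_int4(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]))
assert packed.item() == 0x76543210
```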
@@ -126,35 +132,33 @@ def test_end_to_end_asymmetric_quantization(
     assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32
     assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32

-    save_file(compressed_state_dict, tmp_path / "model.safetensors")
-
-    reconstructed_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
-
-    reconstructed_weights = {}
-    for module_name, module_data in reconstructed_gen:
-        reconstructed_weights[module_name] = module_data
-
-    assert "layer1" in reconstructed_weights
-    assert "layer2" in reconstructed_weights
-    assert "weight" in reconstructed_weights["layer1"]
-    assert "weight" in reconstructed_weights["layer2"]
-
-    assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape
-    assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape
-
     new_model = SimpleModel()
-    new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"]
-    new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"]
-
-    test_input = torch.randn(1, 512)
-    with torch.no_grad():
-        output = new_model(test_input)
-
-    assert output.shape == (1, 128)
-    assert not torch.isnan(output).any()
-    assert not torch.isinf(output).any()
+    apply_quantization_config(new_model, quant_config)
+
+    for module_name in ["layer1", "layer2"]:
+        module = getattr(new_model, module_name)
+        prefix = f"{module_name}."
+        for key, value in compressed_state_dict.items():
+            if key.startswith(prefix):
+                param_name = key[len(prefix):]
+                if hasattr(module, param_name):
+                    getattr(module, param_name).data = value.clone()
+                else:
+                    module.register_parameter(
+                        param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                    )
+
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.decompress_model(new_model)
+
+    assert new_model.layer1.weight.shape == original_weights["layer1"].shape
+    assert new_model.layer2.weight.shape == original_weights["layer2"].shape
+    assert new_model.layer1.weight.dtype.is_floating_point
+    assert new_model.layer2.weight.dtype.is_floating_point
+    assert not torch.isnan(new_model.layer1.weight).any()
+    assert not torch.isnan(new_model.layer2.weight).any()
+    assert not torch.isinf(new_model.layer1.weight).any()
+    assert not torch.isinf(new_model.layer2.weight).any()


 @pytest.mark.parametrize("num_bits", [4, 8])
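The copy loop above is the general pattern for loading a compressed state dict into a freshly quantized model: parameters that `apply_quantization_config` already created are overwritten in place, while packed tensors it does not create (e.g. `weight_packed`) are registered as new parameters. A hedged generalization, with a hypothetical helper name of our own:

```python
import torch
from torch.nn import Module

# Hypothetical helper (not part of compressed-tensors): load a compressed
# state dict into any model whose modules already carry quantization params.
def load_compressed_state_dict(model: Module, state_dict: dict) -> None:
    for key, value in state_dict.items():
        module_name, _, param_name = key.rpartition(".")
        module = model.get_submodule(module_name)  # "" resolves to the model
        if hasattr(module, param_name):
            getattr(module, param_name).data = value.clone()
        else:
            module.register_parameter(
                param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
            )
```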
@@ -174,6 +178,7 @@ def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration):
         strategy=QuantizationStrategy.GROUP,
         group_size=128,
     )
+    quant_config.format = CompressionFormat.pack_quantized.value

     class SingleLayer(Module):
         def __init__(self):
@@ -194,31 +199,26 @@ def __init__(self):
         model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme
     )

-    save_file(compressed_state_dict, tmp_path / "model.safetensors")
-
-    reconstructed_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
-
-    reconstructed = {}
-    for module_name, module_data in reconstructed_gen:
-        reconstructed[module_name] = module_data
-
-    assert "layer" in reconstructed
-    assert "weight" in reconstructed["layer"]
-    assert reconstructed["layer"]["weight"].shape == shape
-
-    decompressed_weights = reconstructed["layer"]["weight"]
+    new_model = SingleLayer()
+    apply_quantization_config(new_model, quant_config)
+
+    module = new_model.layer
+    for key, value in compressed_state_dict.items():
+        if key.startswith("layer."):
+            param_name = key[len("layer."):]
+            if hasattr(module, param_name):
+                getattr(module, param_name).data = value.clone()
+            else:
+                module.register_parameter(
+                    param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                )
+
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.decompress_model(new_model)
+
+    decompressed_weights = new_model.layer.weight
+    assert decompressed_weights.shape == shape
     assert not torch.isnan(decompressed_weights).any()
     assert not torch.isinf(decompressed_weights).any()
-
-    assert decompressed_weights.abs().max() < 100
-    assert decompressed_weights.abs().max() > 0.01
-
-
-if __name__ == "__main__":
-    test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128)
-    test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None)
-    test_asymmetric_quantization_accuracy(4)
-    test_asymmetric_quantization_accuracy(8)
-    print("All tests passed!")
+    threshold = torch.std(torch.rand(shape) - torch.rand(shape))
+    assert torch.std(biased_weights - decompressed_weights) < threshold
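The new tolerance replaces the old hard-coded magnitude bounds with a statistical one: for two independent tensors uniform on [0, 1), the difference has variance 1/12 + 1/12 = 1/6, so `threshold` sits near 0.41 regardless of `shape`. A quick numeric check of that constant:

```python
import torch

# std of (u1 - u2) for independent U(0, 1) draws: sqrt(1/12 + 1/12) ≈ 0.408
a = torch.rand(1_000_000)
b = torch.rand(1_000_000)
print(torch.std(a - b).item())  # ~0.408 == (1 / 6) ** 0.5
```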