25
25
QuantizationConfig ,
26
26
QuantizationStatus ,
27
27
)
28
- from compressed_tensors .quantization .lifecycle import (
29
- apply_quantization_config ,
30
- apply_quantization_status ,
31
- )
28
+ from compressed_tensors .quantization .lifecycle import apply_quantization_config
32
29
from tests .testing_utils import requires_accelerate
33
30
from transformers import AutoModelForCausalLM
34
31
@@ -105,7 +102,9 @@ def test_target_prioritization(mock_frozen):
105
102
106
103
107
104
def test_apply_quantization_config_tinyllama ():
108
- quant_config = get_sample_tinyllama_quant_config (status = "calibration" )
105
+ quant_config = get_sample_tinyllama_quant_config (
106
+ status = QuantizationStatus .CALIBRATION
107
+ )
109
108
model = get_tinyllama_model ()
110
109
111
110
# check that model is not already quantized
@@ -146,7 +145,8 @@ def test_apply_quantization_config_tinyllama():
146
145
# test quantization compression
147
146
# sample forward pass to fill scales, zps
148
147
model (torch .zeros ((1 , 1 ), dtype = int ), torch .zeros ((1 , 1 ), dtype = int ))
149
- apply_quantization_status (model , QuantizationStatus .COMPRESSED )
148
+ quant_config .quantization_status = QuantizationStatus .COMPRESSED
149
+ apply_quantization_config (model , quant_config )
150
150
for name , module in model .named_modules ():
151
151
if name in quant_config .ignore :
152
152
continue
@@ -157,7 +157,6 @@ def test_apply_quantization_config_tinyllama():
157
157
inputs = True ,
158
158
weights = True ,
159
159
expected_status = QuantizationStatus .COMPRESSED ,
160
- expected_dtype = torch .int8 ,
161
160
)
162
161
163
162
@@ -218,7 +217,9 @@ def get_tinyllama_model():
218
217
)
219
218
220
219
221
- def get_sample_tinyllama_quant_config (status : str = "frozen" ):
220
+ def get_sample_tinyllama_quant_config (
221
+ status : QuantizationStatus = QuantizationStatus .FROZEN ,
222
+ ):
222
223
config_dict = {
223
224
"quant_method" : "compressed-tensors" ,
224
225
"format" : "fakequant" ,
@@ -270,7 +271,7 @@ def test_apply_quantization_status(caplog, target, should_raise_warning):
270
271
# load a dense, unquantized tiny llama model
271
272
model = get_tinyllama_model ()
272
273
quantization_config_dict = {
273
- "quant_method" : "sparseml " ,
274
+ "quant_method" : "compressed-tensors " ,
274
275
"format" : "pack-quantized" ,
275
276
"global_compression_ratio" : None ,
276
277
"config_groups" : {
0 commit comments