21
21
22
22
import torch
23
23
from compressed_tensors .config import CompressionFormat
24
- from compressed_tensors .quantization .lifecycle .compressed import (
25
- compress_quantized_weights ,
26
- )
27
24
from compressed_tensors .quantization .lifecycle .initialize import (
28
25
initialize_module_for_quantization ,
29
26
)
35
32
from compressed_tensors .quantization .quant_scheme import QuantizationScheme
36
33
from compressed_tensors .quantization .utils import (
37
34
KV_CACHE_TARGETS ,
38
- infer_quantization_status ,
39
35
is_kv_cache_quant_scheme ,
40
36
)
41
37
from compressed_tensors .utils .helpers import deprecated , replace_module
# Public API of this module.
__all__ = [
    "load_pretrained_quantization_parameters",
    "apply_quantization_config",
    "find_name_or_class_matches",
]
@@ -154,20 +149,27 @@ def apply_quantization_config(
154
149
155
150
# replace with run compressed if applicable
156
151
# FUTURE: move this to model compressor
157
- if isinstance (submodule , torch .nn .Linear ) and run_compressed :
158
- format = config .format
159
- if format != CompressionFormat .dense .value :
160
- if isinstance (submodule , torch .nn .Linear ):
161
- # TODO: expand to more module types
162
- compressed_linear = CompressedLinear .from_linear (
163
- submodule ,
164
- quantization_scheme = scheme ,
165
- quantization_format = format ,
166
- )
167
- replace_module (model , name , compressed_linear )
168
-
169
- # apply current quantization status across all targeted layers
170
- apply_quantization_status (model , config .quantization_status )
152
+ if (
153
+ run_compressed
154
+ and isinstance (submodule , torch .nn .Linear )
155
+ and config .format != CompressionFormat .dense .value
156
+ ):
157
+ # TODO: expand to more module types
158
+ compressed_linear = CompressedLinear .from_linear (
159
+ submodule ,
160
+ quantization_scheme = scheme ,
161
+ quantization_format = config .format ,
162
+ )
163
+ replace_module (model , name , compressed_linear )
164
+
165
+ else :
166
+ initialize_module_for_quantization (
167
+ submodule ,
168
+ force_zero_point = config .quantization_status
169
+ != QuantizationStatus .COMPRESSED ,
170
+ )
171
+
172
+ submodule .quantization_status = config .quantization_status
171
173
172
174
173
175
def process_quantization_config (config : QuantizationConfig ) -> QuantizationConfig :
@@ -206,29 +208,6 @@ def process_kv_cache_config(
206
208
return config
207
209
208
210
209
def apply_quantization_status(model: Module, status: QuantizationStatus):
    """
    Advance, in place, the quantization lifecycle of ``model`` up to ``status``.

    :param model: model to apply quantization to
    :param status: status to update the module to
    """
    current_status = infer_quantization_status(model)

    # Initialize quantization parameters only when crossing the INITIALIZED
    # threshold (target at or past it, current state still before it).
    crossing_initialized = (
        status >= QuantizationStatus.INITIALIZED
        and QuantizationStatus.INITIALIZED > current_status
    )
    if crossing_initialized:
        # Don't force zero-point creation when heading straight to COMPRESSED.
        force_zp = status != QuantizationStatus.COMPRESSED

        model.apply(
            lambda module: initialize_module_for_quantization(
                module, force_zero_point=force_zp
            )
        )

    # Compress weights only when crossing the COMPRESSED threshold.
    crossing_compressed = (
        current_status < status
        and status >= QuantizationStatus.COMPRESSED
        and QuantizationStatus.COMPRESSED > current_status
    )
    if crossing_compressed:
        model.apply(compress_quantized_weights)
211
@deprecated (
233
212
message = "This function is deprecated and will be removed in a future release."
234
213
"Please use `match_targets` from `compressed_tensors.utils.match` instead."
@@ -254,14 +233,6 @@ def find_name_or_class_matches(
254
233
return match_targets (name , module , targets )
255
234
256
235
257
- def _infer_status (model : Module ) -> Optional [QuantizationStatus ]:
258
- for module in model .modules ():
259
- status = getattr (module , "quantization_status" , None )
260
- if status is not None :
261
- return status
262
- return None
263
-
264
-
265
236
def _load_quant_args_from_mapping (
266
237
base_name : str , module_name : str , module : Module , mapping : Dict
267
238
):
0 commit comments