@@ -206,34 +206,33 @@ def process_kv_cache_config(
     return config
 
 
-def apply_quantization_status(model: Module, status: QuantizationStatus):
+def apply_quantization_status(module: Module, status: QuantizationStatus):
     """
     Applies in place the quantization lifecycle up to the given status
 
-    :param model: model to apply quantization to
+    :param module: module to apply quantization to
     :param status: status to update the module to
     """
 
-    if status >= QuantizationStatus.INITIALIZED:
-        force_zero_point_init = status != QuantizationStatus.COMPRESSED
-
-        # When decompressing, we set the scale_dtype as the model's dtype
-        # This is because the normal workflow of using the weight's dtype
-        # will be incorrect as the model weight will be compressed
-        # Therefore, use the dtype set by the user using the PretrainedModel
-        scale_dtype = None
-        if status == QuantizationStatus.FROZEN:
-            if hasattr(model, "dtype"):
-                scale_dtype = model.dtype
-
-        model.apply(
-            lambda module: initialize_module_for_quantization(
-                module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
-            )
+    force_zero_point_init = status != QuantizationStatus.COMPRESSED
+
+    # When decompressing, we set the scale_dtype as the model's dtype
+    # This is because the normal workflow of using the weight's dtype
+    # will be incorrect as the model weight will be compressed
+    # Therefore, use the dtype set by the user using the PretrainedModel
+    scale_dtype = None
+    if status == QuantizationStatus.FROZEN:
+        if hasattr(module, "dtype"):
+            scale_dtype = module.dtype
+
+    module.apply(
+        lambda module: initialize_module_for_quantization(
+            module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
         )
+    )
 
     if status >= QuantizationStatus.COMPRESSED:
-        model.apply(compress_quantized_weights)
+        module.apply(compress_quantized_weights)
 
 
 @deprecated(
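The hunk above renames the parameter from model to module, so the lifecycle helper can be applied to any torch Module (a single layer or submodule, not only a full model), and it lifts the initialization block out of the status >= INITIALIZED guard so it always runs. Below is a minimal usage sketch, not taken from this diff: the import paths and the Linear layer are assumptions for illustration, and in practice the target modules would typically already have a quantization scheme attached (for example via apply_quantization_config) before this call has any effect.

# Minimal usage sketch (assumed import paths; illustrative only).
import torch
from compressed_tensors.quantization import (
    QuantizationStatus,
    apply_quantization_status,
)

# With the rename, the first argument is any torch.nn.Module, e.g. a single
# layer rather than a whole model. This Linear layer is a hypothetical example.
layer = torch.nn.Linear(16, 16)

# Per the diff above: applies initialize_module_for_quantization to every
# submodule, then compresses quantized weights once the status reaches
# COMPRESSED or later.
apply_quantization_status(layer, QuantizationStatus.INITIALIZED)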