Commit fc2e102

lifecycle updates for overwriting config
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent: cc27568

File tree (1 file changed: +18, -19 lines)

  • src/compressed_tensors/quantization/lifecycle/apply.py

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 18 additions & 19 deletions
@@ -206,34 +206,33 @@ def process_kv_cache_config(
     return config
 
 
-def apply_quantization_status(model: Module, status: QuantizationStatus):
+def apply_quantization_status(module: Module, status: QuantizationStatus):
     """
     Applies in place the quantization lifecycle up to the given status
 
-    :param model: model to apply quantization to
+    :param module: module to apply quantization to
     :param status: status to update the module to
     """
 
-    if status >= QuantizationStatus.INITIALIZED:
-        force_zero_point_init = status != QuantizationStatus.COMPRESSED
-
-        # When decompressing, we set the scale_dtype as the model's dtype
-        # This is because the normal workflow of using the weight's dtype
-        # will be incorrect as the model weight will be compressed
-        # Therfore, use the dtype set by the user using the PretrainedModel
-        scale_dtype = None
-        if status == QuantizationStatus.FROZEN:
-            if hasattr(model, "dtype"):
-                scale_dtype = model.dtype
-
-        model.apply(
-            lambda module: initialize_module_for_quantization(
-                module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
-            )
+    force_zero_point_init = status != QuantizationStatus.COMPRESSED
+
+    # When decompressing, we set the scale_dtype as the model's dtype
+    # This is because the normal workflow of using the weight's dtype
+    # will be incorrect as the model weight will be compressed
+    # Therefore, use the dtype set by the user using the PretrainedModel
+    scale_dtype = None
+    if status == QuantizationStatus.FROZEN:
+        if hasattr(module, "dtype"):
+            scale_dtype = module.dtype
+
+    module.apply(
+        lambda module: initialize_module_for_quantization(
+            module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
         )
+    )
 
     if status >= QuantizationStatus.COMPRESSED:
-        model.apply(compress_quantized_weights)
+        module.apply(compress_quantized_weights)
 
 
 @deprecated(
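
For orientation, here is a minimal usage sketch of the renamed entry point after this change. The import paths are assumptions inferred from the file path shown above, the toy nn.Linear is illustrative only, and the sketch assumes initialize_module_for_quantization simply skips modules that have no quantization scheme attached.

    import torch.nn as nn

    # Assumed import paths, inferred from the file tree above
    from compressed_tensors.quantization import QuantizationStatus
    from compressed_tensors.quantization.lifecycle.apply import (
        apply_quantization_status,
    )

    # Illustrative stand-in; any nn.Module is accepted now that the function
    # is no longer documented as taking only a top-level model
    layer = nn.Linear(16, 16)

    # After this commit the INITIALIZED guard is gone: initialization runs
    # for every status, and zero-point init is forced unless COMPRESSED
    apply_quantization_status(layer, QuantizationStatus.INITIALIZED)

    # For FROZEN, scale_dtype falls back to module.dtype when the attribute
    # exists (e.g. on a PreTrainedModel); plain nn.Linear has none, so it
    # stays None here
    apply_quantization_status(layer, QuantizationStatus.FROZEN)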
