
Commit 6ba47e5

clear previously initialized qparams
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent fb32778 commit 6ba47e5

3 files changed: 31 additions, 12 deletions

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 0 additions & 8 deletions
@@ -243,14 +243,6 @@ def find_name_or_class_matches(
     return match_targets(name, module, targets)
 
 
-def _infer_status(model: Module) -> Optional[QuantizationStatus]:
-    for module in model.modules():
-        status = getattr(module, "quantization_status", None)
-        if status is not None:
-            return status
-    return None
-
-
 def _load_quant_args_from_mapping(
     base_name: str, module_name: str, module: Module, mapping: Dict
 ):

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 30 additions & 3 deletions
@@ -33,6 +33,7 @@
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
+    delete_offload_parameter,
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
@@ -61,10 +62,11 @@ def initialize_module_for_quantization(
     force_zero_point: bool = True,
 ):
     """
-    attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme
+    Attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme.
 
-    apply to full model with `model.apply(initialize_module_for_quantization)`
+    Previously initialized scales and zero points will be removed from
+    module if they no longer apply to the scheme
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -73,6 +75,8 @@ def initialize_module_for_quantization(
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
     """
+    _clear_all_qparams(module)
+
     # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
@@ -134,6 +138,29 @@ def is_attention_module(module: Module):
     )
 
 
+def _clear_all_qparams(
+    module: Module,
+):
+    """
+    Clear all previously registered quantization parameters from module
+
+    :param module: module to clear qparams from
+    """
+    keys = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
+        f"{base_name}_{suffix}"
+        for base_name in ("input", "weight", "output")
+        for suffix in (
+            "global_scale",
+            "scale",
+            "zero_point",
+            "g_idx",
+        )
+    ]
+    for key in keys:
+        if hasattr(module, key):
+            delete_offload_parameter(module, key)
+
+
 def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
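
To illustrate the new behavior, here is a minimal, self-contained sketch of the clearing step that _clear_all_qparams performs. It uses plain torch and a local clear_all_qparams helper, with delattr standing in for delete_offload_parameter and the literal "k_scale"/"v_scale" strings standing in for the KVCacheScaleType values; those substitutions are assumptions made for the sketch, not part of this commit.

import torch
from torch.nn import Linear, Parameter


def clear_all_qparams(module: torch.nn.Module) -> None:
    # mirror of the key enumeration above: kv-cache scales plus every
    # (base_name, suffix) quantization-parameter combination
    keys = ["k_scale", "v_scale"] + [  # assumed KVCacheScaleType values
        f"{base_name}_{suffix}"
        for base_name in ("input", "weight", "output")
        for suffix in ("global_scale", "scale", "zero_point", "g_idx")
    ]
    for key in keys:
        if hasattr(module, key):
            # the commit calls delete_offload_parameter(module, key);
            # plain attribute deletion is enough for this in-memory sketch
            delattr(module, key)


layer = Linear(16, 16)

# first initialization: an asymmetric scheme registers a scale and a zero point
layer.register_parameter("weight_scale", Parameter(torch.ones(1)))
layer.register_parameter("weight_zero_point", Parameter(torch.zeros(1)))

# re-initializing for a symmetric scheme: without the clearing step, the stale
# weight_zero_point from the previous scheme would remain on the module
clear_all_qparams(layer)
layer.register_parameter("weight_scale", Parameter(torch.ones(1)))

assert not hasattr(layer, "weight_zero_point")
print([name for name, _ in layer.named_parameters()])
# ['weight', 'bias', 'weight_scale']

The real code reaches for delete_offload_parameter rather than a plain delete, presumably so that parameters registered on offloaded modules are removed from their offloaded storage as well.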

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ def get_sample_tinyllama_quant_config(
         [("Linear", "re:.*foobarbaz"), True],
     ],
 )
-def test_apply_quantization_status(caplog, target, should_raise_warning):
+def test_apply_quantization_config(caplog, target, should_raise_warning):
     import logging
 
     # load a dense, unquantized tiny llama model
