Commit ca03ad5

reduce diff
Signed-off-by: Kyle Sayers <[email protected]>
1 parent db53247 commit ca03ad5

File tree

1 file changed: +14 -15 lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 14 additions & 15 deletions
@@ -14,22 +14,23 @@
 
 
 import logging
-from enum import Enum
 from typing import Optional, Tuple
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
-from compressed_tensors.quantization.quant_args import (
+from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
+    KVCacheScaleType,
     QuantizationArgs,
+    QuantizationMetadata,
+    QuantizationScheme,
+    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.utils import (
     is_fp4,
     is_kv_cache_quant_scheme,
@@ -53,21 +54,17 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme
+    Attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme.
 
-    apply to full model with `model.apply(initialize_module_for_quantization)`
+    Previously initialized scales and zero points will be removed from
+    module if they no longer apply to the scheme
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -80,6 +77,8 @@ def initialize_module_for_quantization(
     if scheme is None:
         return
 
+    QuantizationMetadata.clear_all_qparams(module)
+
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
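
Note: the snippet below is an illustrative sketch, not part of this commit. It shows how the consolidated compressed_tensors.quantization imports and the new clear-all-qparams behavior from this diff might be exercised on a single layer. The QuantizationScheme and QuantizationArgs constructor fields used here (targets, weights, num_bits, symmetric) are assumptions for illustration and may differ from the actual library API.

# Illustrative sketch only (not from this commit); fields marked as assumed
# below may differ from the real compressed-tensors API.
import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)

layer = torch.nn.Linear(16, 16)

# Assumed constructor fields: targets, weights, num_bits, symmetric.
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

# Attaches weight scales/zero points and wraps the module's forward pass.
initialize_module_for_quantization(layer, scheme)

# Re-initializing under a (possibly different) scheme now first clears any
# previously attached qparams via QuantizationMetadata.clear_all_qparams,
# per the change in this diff.
initialize_module_for_quantization(layer, scheme)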
