@@ -14,22 +14,23 @@
 
 
 import logging
-from enum import Enum
 from typing import Optional, Tuple
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
-from compressed_tensors.quantization.quant_args import (
+from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
+    KVCacheScaleType,
     QuantizationArgs,
+    QuantizationMetadata,
+    QuantizationScheme,
+    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.utils import (
     is_fp4,
     is_kv_cache_quant_scheme,
@@ -53,21 +54,17 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme
+    Attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme.
 
-    apply to full model with `model.apply(initialize_module_for_quantization)`
+    Previously initialized scales and zero points will be removed from
+    module if they no longer apply to the scheme
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -80,6 +77,8 @@ def initialize_module_for_quantization(
     if scheme is None:
         return
 
+    QuantizationMetadata.clear_all_qparams(module)
+
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
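Not part of the diff, but a minimal sketch of the behavior change the updated docstring describes: re-initializing a module under a different scheme now clears stale qparams (via `QuantizationMetadata.clear_all_qparams`) before attaching new ones, instead of leaving old scale/zero-point parameters behind. The import path for `initialize_module_for_quantization` and the scheme arguments below are illustrative assumptions, not taken from this PR.

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)
# Assumed import path: the function is defined in lifecycle/initialize.py
from compressed_tensors.quantization.lifecycle import (
    initialize_module_for_quantization,
)

layer = torch.nn.Linear(64, 64)

# First initialization attaches W8 qparams (weight scale / zero point).
w8 = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8),
)
initialize_module_for_quantization(layer, scheme=w8)

# Re-initializing with a W4 group scheme: qparams from the earlier scheme
# that no longer apply are cleared before the new ones are registered.
w4 = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=4,
        strategy=QuantizationStrategy.GROUP,
        group_size=32,
    ),
)
initialize_module_for_quantization(layer, scheme=w4)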