@@ -14,23 +14,22 @@
 
 
 import logging
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
-from compressed_tensors.quantization import (
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
+from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
-    QuantizationMetadata,
-    QuantizationScheme,
-    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     is_fp4,
     is_kv_cache_quant_scheme,
@@ -54,17 +53,21 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    Attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme.
+    attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme
 
-    Previously initialized scales and zero points will be removed from
-    module if they no longer apply to the scheme
+    apply to full model with `model.apply(initialize_module_for_quantization)`
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -77,8 +80,6 @@ def initialize_module_for_quantization(
     if scheme is None:
         return
 
-    QuantizationMetadata.clear_all_qparams(module)
-
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
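For context on how the pieces touched by this diff fit together, here is a minimal usage sketch of the pattern the new docstring describes. The top-level import paths and the exact `QuantizationScheme`/`QuantizationArgs` constructor fields are assumptions based on compressed-tensors' usual layout, not something this diff shows; treat it as illustrative, not as the library's definitive API.

import torch.nn as nn

# Import paths assumed from the library's usual re-exports; adjust to
# your installed compressed-tensors version.
from compressed_tensors.quantization import (
    KVCacheScaleType,
    QuantizationArgs,
    QuantizationScheme,
    initialize_module_for_quantization,
)

# The re-added enum carries the serialized parameter names used for
# attention KV-cache scales.
assert KVCacheScaleType.KEY.value == "k_scale"
assert KVCacheScaleType.VALUE.value == "v_scale"

# Illustrative scheme: 8-bit symmetric weight quantization targeting
# Linear layers (field names assumed; check your version).
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# Attach a scheme to each target layer, then initialize qparams
# model-wide via nn.Module.apply; modules without a scheme hit the
# early `return` shown in the diff above.
for layer in model:
    layer.quantization_scheme = scheme
model.apply(initialize_module_for_quantization)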