
Commit db53247

refactor
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8973328 commit db53247

File tree

2 files changed: +31 −14 lines


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -14,23 +14,22 @@
 
 
 import logging
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
-from compressed_tensors.quantization import (
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
+from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
-    QuantizationMetadata,
-    QuantizationScheme,
-    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     is_fp4,
     is_kv_cache_quant_scheme,
```
```diff
@@ -54,17 +53,21 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    Attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme.
+    attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme
 
-    Previously initialized scales and zero points will be removed from
-    module if they no longer apply to the scheme
+    apply to full model with `model.apply(initialize_module_for_quantization)`
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
```
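
The relocated `KVCacheScaleType` enum carries the parameter names under which key- and value-cache scales are registered on attention modules. Below is a hedged sketch of that pattern; `init_attn_scales_sketch` is a simplified stand-in loosely mirroring what the private `_initialize_attn_scales` helper called by this file is expected to do, not its exact implementation.

```python
import torch
from torch.nn import Module, Parameter

from compressed_tensors.quantization.lifecycle.initialize import KVCacheScaleType


def init_attn_scales_sketch(module: Module) -> None:
    # Register one scalar scale each for the K and V caches, under the
    # names carried by the enum values ("k_scale" / "v_scale").
    for scale_type in (KVCacheScaleType.KEY, KVCacheScaleType.VALUE):
        scale = Parameter(torch.empty(1), requires_grad=False)
        module.register_parameter(scale_type.value, scale)
```

Downstream code can then fetch the scales by the same names, e.g. `getattr(module, KVCacheScaleType.KEY.value)`.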
```diff
@@ -77,8 +80,6 @@ def initialize_module_for_quantization
     if scheme is None:
         return
 
-    QuantizationMetadata.clear_all_qparams(module)
-
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
```
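
The updated docstring points at `Module.apply` for whole-model initialization. Here is a minimal sketch of that pattern, assuming an 8-bit symmetric weight scheme; the scheme values are illustrative and not taken from this commit, and in a real flow `apply_quantization_config` resolves targets and statuses first:

```python
from functools import partial

from torch.nn import Linear, Sequential

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import (
    initialize_module_for_quantization,
)

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

model = Sequential(Linear(64, 64), Linear(64, 16))
# Module.apply visits the container and every submodule; modules without
# a `weight` attribute are expected to be skipped by the initializer.
model.apply(partial(initialize_module_for_quantization, scheme=scheme))
```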

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@
 
 
 __all__ = [
+    "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
```
```diff
@@ -235,6 +236,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
     return q_min, q_max
 
 
+def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
+    """
+    Checks the quantization status of a model. Assumes all modules in the model have
+    the same status, so only the first quantized module is checked.
+
+    :param model: model to check quantization status for
+    :return: quantization status if the model is quantized, otherwise None
+    """
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
+
 def is_module_quantized(module: Module) -> bool:
     """
     Check if a module is quantized, based on the existence of a non-empty quantization
```
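
A hedged usage sketch for the added helper: it returns the first `quantization_status` attribute found while walking the model's modules, so marking any one module is enough for the whole model to report that status. The toy model below is an assumption for illustration:

```python
from torch.nn import Linear, Sequential

from compressed_tensors.quantization import QuantizationStatus
from compressed_tensors.quantization.utils import infer_quantization_status

model = Sequential(Linear(8, 8), Linear(8, 4))
assert infer_quantization_status(model) is None  # nothing marked yet

# Simulate a lifecycle step tagging a module with its status
model[0].quantization_status = QuantizationStatus.INITIALIZED
assert infer_quantization_status(model) is QuantizationStatus.INITIALIZED
```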

0 commit comments