@@ -16,8 +16,6 @@
 import logging
-import math
-import warnings
 from enum import Enum
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
@@ -32,7 +30,11 @@
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import (
+    is_fp4,
+    is_kv_cache_quant_scheme,
+    strict_divide,
+)
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -102,7 +104,7 @@ def initialize_module_for_quantization(
     if scheme.input_activations is not None:
         base_name = "input"
         args = scheme.input_activations
-        observed_shape = weight.shape[-1:]
+        observed_shape = (1, weight.size(-1))
         observed_dtype = weight.dtype

     if scheme.weights is not None:
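This hunk makes the observed shape for input activations explicitly two-dimensional: a placeholder leading dimension for the token axis plus the hidden size, rather than a 1-D `torch.Size` slice. A minimal sketch of the difference (the weight sizes below are illustrative, not from the source):

```python
import torch

weight = torch.empty(4096, 11008)  # (out_features, in_features); sizes are made up

# before: a 1-D torch.Size carrying only the hidden dimension
old_observed_shape = weight.shape[-1:]     # torch.Size([11008])

# after: a 2-D tuple whose leading 1 stands in for the token dimension
new_observed_shape = (1, weight.size(-1))  # (1, 11008)
```

The extra leading dimension is what lets the generalized group-quantization branch further down preserve leading dimensions when computing the scale shape.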
@@ -148,7 +150,7 @@ def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    observed_shape: torch.Size,
+    observed_shape: Tuple[int],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
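One note on the new annotation: in `typing`, `Tuple[int]` denotes a tuple of exactly one `int`, while observed shapes here can have one or more dimensions; the variable-length spelling would be `Tuple[int, ...]`. A quick illustration:

```python
from typing import Tuple

shape_1d: Tuple[int] = (11008,)         # matches the annotation as written
shape_nd: Tuple[int, ...] = (1, 11008)  # variable-length tuple of ints
```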
@@ -191,8 +193,8 @@ def _initialize_scale_zero_point(
             raise ValueError("Group quant requires at least 1 observed dimension")

         group_size = quantization_args.group_size
-        num_groups = _strict_divide(observed_shape[-1], group_size, strategy)
-        expected_shape = (num_groups, group_size)
+        num_groups = strict_divide(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)

         # initialize activation ordering if applicable
         if actorder == ActivationOrdering.GROUP:
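The rewritten group branch folds only the last observed dimension into groups and keeps every leading dimension, instead of returning a fixed `(num_groups, group_size)` pair. A worked example with illustrative sizes (assuming the sizes divide evenly, as `strict_divide` enforces):

```python
group_size = 128

# weight case: observed_shape is the weight shape
observed_shape = (4096, 11008)
num_groups = observed_shape[-1] // group_size        # 86
expected_shape = (*observed_shape[:-1], num_groups)  # (4096, 86): one scale per group per row

# activation case: observed_shape is (1, hidden) from the earlier hunk
observed_shape = (1, 11008)
expected_shape = (*observed_shape[:-1], observed_shape[-1] // group_size)  # (1, 86)
```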
@@ -208,8 +210,8 @@ def _initialize_scale_zero_point(
             raise ValueError("Block quant requires at least 2 observed dimensions")

         block_structure = quantization_args.block_structure
-        num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy)
-        num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy)
+        num_rows = strict_divide(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = strict_divide(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)

     # 2. Identify quantization scale and zp dtype
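For block quantization the scale grid holds one entry per block, so both trailing dimensions must divide evenly by the block structure. Illustrative arithmetic:

```python
block_structure = (128, 128)
observed_shape = (4096, 11008)  # made-up weight shape

num_rows = observed_shape[-2] // block_structure[-2]  # 4096 / 128 = 32
num_cols = observed_shape[-1] // block_structure[-1]  # 11008 / 128 = 86
expected_shape = (num_rows, num_cols)                 # (32, 86): one scale per block
```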
@@ -264,16 +266,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int:
-    out = observed // divisor
-    if out * divisor != observed:
-        raise ValueError(
-            f"{strategy} quantization strategy requires strict division of "
-            f"weight/activation size {observed} and group/block size {divisor}. "
-            "consider reducing the group/block size or ignoring modules with weights "
-            f"not divisible by {divisor}"
-        )
-
-    return out
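The deleted module-private `_strict_divide` is superseded by the `strict_divide` imported from `compressed_tensors.quantization.utils` in the first hunk. Assuming the relocated helper keeps the signature and divide-exactly-or-raise behavior shown above (import paths inferred from the package layout visible in this diff), usage looks like:

```python
from compressed_tensors.quantization.quant_args import QuantizationStrategy
from compressed_tensors.quantization.utils import strict_divide

strict_divide(11008, 128, QuantizationStrategy.GROUP)  # 86
strict_divide(11000, 128, QuantizationStrategy.GROUP)  # raises ValueError:
                                                       # 11000 not divisible by 128
```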