import logging
-import math
-import warnings
-from typing import Optional
+from typing import Optional, Tuple

import torch
from compressed_tensors.quantization import (
    FP8_E4M3_DATA,
    ActivationOrdering,
+    DynamicType,
    KVCacheScaleType,
    QuantizationArgs,
    QuantizationMetadata,

from compressed_tensors.quantization.lifecycle.forward import (
    wrap_module_forward_quantized,
)
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import (
+    is_fp4,
+    is_kv_cache_quant_scheme,
+    strategy_cdiv,
+)
from compressed_tensors.utils import (
    disable_hf_hook,
    get_execution_device,

__all__ = [
    "initialize_module_for_quantization",
    "is_attention_module",
+    "initialize_qparams",
]

@@ -69,10 +73,8 @@ def initialize_module_for_quantization(
    :param force_zero_point: whether to force initialization of a zero point for
        symmetric quantization
    """
-    # TODO: don't initialize parameters when running decompression
    scheme = scheme or getattr(module, "quantization_scheme", None)
    if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
        return

    QuantizationMetadata.clear_all_qparams(module)
@@ -82,38 +84,52 @@ def initialize_module_for_quantization(
        _initialize_attn_scales(module)

    else:
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
+            )
+            return
+
        if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
+            initialize_qparams(
                module,
                "input",
                scheme.input_activations,
+                observed_shape=weight.shape[-1:],
+                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
-
-        if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
-                )
+            initialize_qparams(
+                module,
+                "weight",
+                scheme.weights,
+                observed_shape=weight.shape,
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
+
+        output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
+        if scheme.output_activations is not None and not output_is_kv_cache:
+            initialize_qparams(
+                module,
+                "output",
+                scheme.output_activations,
+                observed_shape=weight.shape[:-1],
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED
@@ -132,22 +148,40 @@ def is_attention_module(module: Module):
    )


-def _initialize_scale_zero_point(
+def initialize_qparams(
    module: Module,
    base_name: str,
    quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: Tuple[int],
+    observed_dtype: torch.dtype,
    force_zero_point: bool = True,
):
-    if quantization_args.dynamic is True:
-        return
+    """
+    Initialize quantization parameters for a given basename according to the passed
+    quantization args. The shape and dtype of the observed weight/activation must also
+    be provided.
+
+    Scales will always be initialized. Global scales are initialized depending on args.
+    Zero points will be initialized if not symmetric or if `force_zero_point` is True.
+
+    :param module: module to register qparams to
+    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
+    :param quantization_args: arguments for quantization
+    :param observed_shape: last (right-most) known dimensions of the observed weight/act
+    :param observed_dtype: dtype of the observed weight/act
+    :param force_zero_point: force the zero_point parameter to be initialized
+    """
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing initialization ops on cpu

-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all initialization for fully dynamic quantization
+    if dynamic is True:
+        return

-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
        init_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32, device=device),
            requires_grad=False,
@@ -156,56 +190,55 @@ def _initialize_scale_zero_point(
            module, f"{base_name}_global_scale", init_global_scale
        )

-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy == QuantizationStrategy.TENSOR:
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.TOKEN:
        expected_shape = (1, 1)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 2:
+            raise ValueError("Channel quant requires at least 2 observed dimensions")
+
+        expected_shape = (observed_shape[-2], 1)
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
    else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+        assert False, f"Unknown strategy {strategy}"

-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = module.weight.dtype
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype

    if is_fp4(quantization_args=quantization_args):
        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -221,14 +254,12 @@ def _initialize_scale_zero_point(
            scale_dtype = torch.bfloat16
        zp_dtype = quantization_args.pytorch_dtype()

-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

    if force_zero_point or not quantization_args.symmetric:
        init_zero_point = Parameter(
@@ -237,16 +268,6 @@ def _initialize_scale_zero_point(
        )
        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-

def _initialize_attn_scales(module: Module) -> None:
    """Initialize k_scale, v_scale for self_attn"""
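
For reviewers who want to exercise the refactored helper directly, the following is a minimal sketch, not part of the diff above. It assumes initialize_qparams is importable from compressed_tensors.quantization.lifecycle (as suggested by the __all__ change in this file's imports) and uses only existing QuantizationArgs fields (num_bits, symmetric, strategy, group_size).

# Usage sketch (assumption: import path as noted above).
import torch

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle import initialize_qparams

linear = torch.nn.Linear(in_features=512, out_features=256)
args = QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=128)

initialize_qparams(
    linear,
    "weight",
    args,
    observed_shape=linear.weight.shape,   # (256, 512)
    observed_dtype=linear.weight.dtype,
    force_zero_point=False,
)

# strategy_cdiv(512, 128, GROUP) yields 4 groups, so the registered scale has
# shape (256, 4); with symmetric=True and force_zero_point=False, no
# weight_zero_point parameter is created.
print(linear.weight_scale.shape)  # torch.Size([256, 4])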