 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
 )
@@ -58,8 +59,8 @@ class KVCacheScaleType(Enum):
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_scale_dtype: Optional[torch.dtype] = None,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -76,51 +77,58 @@ def initialize_module_for_quantization(
     :param scale_dtype: dtype to used for the scales, if overriding the
         weight dtype as the scale dtype
     """
-    # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
         return
 
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
 
     else:
-        if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
-                module,
-                "input",
-                scheme.input_activations,
-                force_zero_point=force_zero_point,
-                scale_dtype=scale_dtype,
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
             )
+            return
+
+        if scheme.input_activations is not None:
+            base_name = "input"
+            args = scheme.input_activations
+            observed_shape = weight.shape[-1:]
+            observed_dtype = force_scale_dtype or weight.dtype
 
         if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                    scale_dtype=scale_dtype,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
+            base_name = "weight"
+            args = scheme.weights
+            observed_shape = weight.shape
+            observed_dtype = force_scale_dtype or weight.dtype
 
         if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations, scale_dtype=scale_dtype
-                )
+            base_name = "output"
+            args = scheme.output_activations
+            observed_shape = weight.shape[:-1]
+            observed_dtype = force_scale_dtype or weight.dtype
+
+        if not is_kv_cache_quant_scheme(scheme):
+            _initialize_scale_zero_point(
+                module,
+                base_name,
+                args,
+                observed_shape=observed_shape,
+                observed_dtype=observed_dtype,
+                force_zero_point=force_zero_point,
+            )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
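As a quick orientation, here is a hedged usage sketch of the updated entry point (editor's illustration, not part of the diff). The import paths, the `QuantizationScheme`/`QuantizationArgs` fields, and the `"group"` strategy string are assumptions drawn from the existing compressed-tensors public API rather than from this change:

```python
# Editor's sketch: exercising the new initialize_module_for_quantization signature.
# Assumes the compressed-tensors public imports below; adjust if the paths differ.
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import initialize_module_for_quantization

layer = torch.nn.Linear(1024, 512)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, type="int", strategy="group", group_size=128),
)

# force_scale_dtype replaces the old scale_dtype keyword; when omitted, the scale
# dtype falls back to the module's weight dtype
initialize_module_for_quantization(layer, scheme, force_scale_dtype=torch.bfloat16)

# scales are registered on the module as offloadable parameters
print(layer.weight_scale.shape, layer.weight_scale.dtype)
```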
@@ -143,19 +151,21 @@ def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: torch.Size,
+    observed_dtype: torch.dtype,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic is True:
-        return
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing initialization ops on cpu
 
-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all initialization for fully dynamic quantization
+    if dynamic is True:
+        return
 
-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
         init_global_scale = Parameter(
             torch.empty(1, dtype=torch.float32, device=device),
             requires_grad=False,
@@ -164,56 +174,49 @@ def _initialize_scale_zero_point(
             module, f"{base_name}_global_scale", init_global_scale
         )
 
-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
-        expected_shape = (1, 1)
-    else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN):
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 1:
+            raise ValueError("Channel quant requires at least 1 observed dimension")
+
+        expected_shape = (observed_shape[-1], 1)
 
-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = _strict_divide(observed_shape[-1], group_size, strategy)
+        expected_shape = (num_groups, group_size)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
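To make the new shape inference concrete, here is a small standalone sketch (editor's illustration, not part of the diff) that mirrors the branches above for a hypothetical 2D weight; the sizes, group size, and block structure are invented for illustration:

```python
# Editor's illustration of the scale/zp shapes produced by the branches above,
# for a hypothetical (out_features, in_features) = (512, 1024) observed weight.
observed_shape = (512, 1024)
group_size = 128              # assumed quantization_args.group_size
block_structure = (128, 128)  # assumed quantization_args.block_structure

expected_shapes = {
    "tensor/token": (1,),
    "channel": (observed_shape[-1], 1),                                    # (1024, 1)
    "group/tensor_group": (observed_shape[-1] // group_size, group_size),  # (8, 128)
    "block": (
        observed_shape[-2] // block_structure[-2],  # 512 / 128 = 4
        observed_shape[-1] // block_structure[-1],  # 1024 / 128 = 8
    ),
}
print(expected_shapes)
```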
@@ -229,14 +232,12 @@ def _initialize_scale_zero_point(
             scale_dtype = torch.float16
         zp_dtype = quantization_args.pytorch_dtype()
 
-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initialize scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
@@ -245,16 +246,6 @@ def _initialize_scale_zero_point(
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-
 
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
@@ -276,3 +267,16 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+
+
+def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int:
+    out = observed // divisor
+    if out * divisor != observed:
+        raise ValueError(
+            f"{strategy} quantization strategy requires strict division of "
+            f"weight/activation size {observed} by group/block size {divisor}. "
+            "Consider reducing the group/block size or ignoring modules with weights "
+            f"not divisible by {divisor}"
+        )
+
+    return out
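For reviewers, a brief illustration of the helper's behavior (editor's note, not part of the diff): it returns the exact quotient when the observed size divides evenly and raises otherwise, so mis-sized groups or blocks fail loudly at initialization instead of silently producing partial groups. The snippet assumes `_strict_divide` and `QuantizationStrategy` are in scope (e.g. when run inside this module):

```python
# Editor's illustration of the intended behavior of _strict_divide.
assert _strict_divide(4096, 128, QuantizationStrategy.GROUP) == 32

try:
    _strict_divide(4000, 128, QuantizationStrategy.GROUP)  # 4000 % 128 != 0
except ValueError as err:
    print(err)  # suggests reducing the group/block size or ignoring the module
```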