from compressed_tensors.quantization.quant_args import (
    FP8_E4M3_DATA,
    ActivationOrdering,
+    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
)
@@ -73,49 +74,58 @@ def initialize_module_for_quantization(
    :param force_zero_point: whether to force initialization of a zero point for
        symmetric quantization
    """
-    # TODO: don't initialize parameters when running decompression
    scheme = scheme or getattr(module, "quantization_scheme", None)
    if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
        return

    if is_attention_module(module):
        # quantized actions based on calltime status
        _initialize_attn_scales(module)

    else:
-        if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
-                module,
-                "input",
-                scheme.input_activations,
-                force_zero_point=force_zero_point,
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
            )
+            return
+
+        if scheme.input_activations is not None:
+            base_name = "input"
+            args = scheme.input_activations
+            observed_shape = weight.shape[-1:]
+            observed_dtype = weight.dtype

        if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
+            base_name = "weight"
+            args = scheme.weights
+            observed_shape = weight.shape
+            observed_dtype = weight.dtype

        if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
-                )
+            base_name = "output"
+            args = scheme.output_activations
+            observed_shape = weight.shape[:-1]
+            observed_dtype = weight.dtype
+
+        if not is_kv_cache_quant_scheme(scheme):
+            _initialize_scale_zero_point(
+                module,
+                base_name,
+                args,
+                observed_shape=observed_shape,
+                observed_dtype=observed_dtype,
+                force_zero_point=force_zero_point,
+            )

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED
@@ -138,18 +148,21 @@ def _initialize_scale_zero_point(
    module: Module,
    base_name: str,
    quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: torch.Size,
+    observed_dtype: torch.dtype,
    force_zero_point: bool = True,
):
-    if quantization_args.dynamic is True:
-        return
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing initialization ops on cpu

-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all initialization for fully dynamic quantization
+    if dynamic is True:
+        return

-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
        init_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32, device=device),
            requires_grad=False,
@@ -158,56 +171,54 @@ def _initialize_scale_zero_point(
            module, f"{base_name}_global_scale", init_global_scale
        )

-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
-        expected_shape = (1, 1)
-    else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN):
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 1:
+            raise ValueError("Channel quant requires at least 1 observed dimension")

+        expected_shape = (observed_shape[-1], 1)
+
-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = module.weight.dtype
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = _strict_divide(observed_shape[-1], group_size, strategy)
+        expected_shape = (num_groups, group_size)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype

    if is_fp4(quantization_args=quantization_args):
        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -223,14 +234,12 @@ def _initialize_scale_zero_point(
            scale_dtype = torch.bfloat16
        zp_dtype = quantization_args.pytorch_dtype()

-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

    if force_zero_point or not quantization_args.symmetric:
        init_zero_point = Parameter(
@@ -239,16 +248,6 @@ def _initialize_scale_zero_point(
        )
        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-

def _initialize_attn_scales(module: Module) -> None:
    """Initlaize k_scale, v_scale for self_attn"""
@@ -270,3 +269,16 @@ def _initialize_attn_scales(module: Module) -> None:
        requires_grad=False,
    )
    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+
+
+def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int:
+    out = observed // divisor
+    if out * divisor != observed:
+        raise ValueError(
+            f"{strategy} quantization strategy requires strict division of "
+            f"weight/activation size {observed} and group/block size {divisor}. "
+            "consider reducing the group/block size or ignoring modules with weights "
+            f"not divisible by {divisor}"
+        )
+
+    return out
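
For context, the helper added at the bottom of the diff replaces the earlier math.ceil-based rounding: group and block sizes must now divide the observed dimension exactly, otherwise initialization raises. The standalone sketch below mirrors that behavior so the new failure mode can be tried in isolation; the plain string passed as strategy and the example sizes are illustrative stand-ins, not compressed_tensors API.

# Standalone sketch of the divisibility check introduced by this diff.
# A plain string stands in for QuantizationStrategy so the example runs
# without importing compressed_tensors.
def strict_divide_sketch(observed: int, divisor: int, strategy: str) -> int:
    out = observed // divisor
    if out * divisor != observed:
        raise ValueError(
            f"{strategy} quantization strategy requires strict division of "
            f"weight/activation size {observed} and group/block size {divisor}."
        )
    return out


if __name__ == "__main__":
    # 11008 input channels with group_size=128 divide cleanly into 86 groups.
    print(strict_divide_sketch(11008, 128, "group"))

    # The pre-refactor code would have rounded a non-divisible block count up
    # with math.ceil and warned; the refactored check raises instead.
    try:
        strict_divide_sketch(11008, 96, "block")
    except ValueError as err:
        print(f"raised: {err}")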