@@ -117,6 +117,12 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
    mto.restore_from_modelopt_state(model_ref, state_dict)


+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
+    quantizer_attr = getattr(quantizer, attr).clone()
+    dist.all_reduce(quantizer_attr, op=op, group=group)
+    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+
+
def tensor_parallel_test_helper(model, config, tp_group):
    # The input to first layer, the column parallel should be the same across all tp ranks
    calib_data = model.get_dummy_input().cuda()
@@ -126,27 +132,39 @@ def forward_loop(model):
        model(calib_data)

    model = mtq.quantize(model, config, forward_loop)
-
    # Sanity check
    forward_loop(model)

    if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
        # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        activation_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)

        # Lets check the row parallel weight amax; it should be the same across all tp ranks
-        weight_amax = model.fc2.weight_quantizer.amax.clone()
-        dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
+        )

    if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
        # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
        input_quantizer = model.fc1.input_quantizer
-        pre_quant_scale = input_quantizer.pre_quant_scale.clone()
-        dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
+        _reduce_quantizer_attr(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
+        )
+
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        # Check act scale
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=tp_group,
+        )
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=tp_group,
+        )

    dist.destroy_process_group()

@@ -159,27 +177,37 @@ def forward_loop(model):

    model = mtq.quantize(model, config, forward_loop)

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
-
    # Input quantizer amax
    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

    # Weight quantizer amax
    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
        for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
    else:
-        reduce_amax(model.fc1.weight_quantizer)
+        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
        for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
    else:
-        reduce_amax(model.fc2.weight_quantizer)
+        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        # Check act scale
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=group,
+        )
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=group,
+        )


def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
@@ -192,33 +220,52 @@ def forward_loop(model):

    model = mtq.quantize(model, config, forward_loop)

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        print("amax before reduce", amax)
-        print("quantizer.amax before reduce", quantizer.amax)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        print("amax after reduce", amax)
-        print("quantizer.amax after reduce", quantizer.amax)
-        assert torch.allclose(amax, quantizer.amax)
+    def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX):
+        quantizer_attr = getattr(quantizer, attr).clone()
+        print("quantizer_attr before reduce", quantizer_attr)
+        print("quantizer.attr before reduce", getattr(quantizer, attr))
+        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=cp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+        print("quantizer_attr after reduce", quantizer_attr)
+        print("quantizer.attr after reduce", getattr(quantizer, attr))
+        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

    # Input quantizer amax
    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)

    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
        for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
    else:
-        reduce_amax(model.fc1.weight_quantizer)
+        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)

    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
        for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
    else:
-        reduce_amax(model.fc2.weight_quantizer)
+        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+
+    # Check act scale
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+        )
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer.awq_lite,
+            "act_scale",
+            dist.ReduceOp.AVG,
+        )


def auto_quantize_helper(model):
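For readers skimming the patch, here is a minimal, self-contained sketch of the consistency check that _reduce_quantizer_attr performs: each rank clones a quantizer statistic, all-reduces the clone across the process group, and asserts that the reduced value still equals the local one, which only holds if calibration already synchronized that statistic across ranks. The SimpleQuantizer stand-in and the single-process "gloo" setup below are assumptions for illustration; the actual tests operate on modelopt quantizers inside a multi-rank NCCL job.

# Illustrative sketch only -- SimpleQuantizer and the single-process "gloo" group
# are assumptions; the real tests exercise modelopt quantizers under NCCL.
import os

import torch
import torch.distributed as dist


class SimpleQuantizer:
    """Stand-in exposing an ``amax`` tensor, similar to a modelopt quantizer."""

    def __init__(self, amax: torch.Tensor):
        self.amax = amax


def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
    # Clone the statistic, reduce it across the group, and verify the reduced
    # value equals the local one -- i.e. calibration already synced it.
    quantizer_attr = getattr(quantizer, attr).clone()
    dist.all_reduce(quantizer_attr, op=op, group=group)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))


if __name__ == "__main__":
    # With world_size=1 the all_reduce is a no-op, so the check trivially
    # passes; with several ranks it fails if any rank holds a different amax.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    _reduce_quantizer_attr(SimpleQuantizer(torch.tensor([1.0])), "amax")
    dist.destroy_process_group()

The choice of reduction op matters: amax is a running maximum, so MAX is the natural way to combine it, while the AWQ-lite act_scale is a per-channel average of activation magnitudes, which is consistent with the patch reducing it with AVG.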