@@ -210,103 +210,6 @@ def forward_loop(model):
     )


-@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite):
-    calib_data = model.get_dummy_input().cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
-
-    # Sanity check
-    forward_loop(model)
-
-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    # Weight quantizer amax
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Check act scale
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-
-
-@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Calib data should be same across each DP rank
-    dp_rank = dist.get_rank(group=dp_group)
-    calib_data = model.get_dummy_input(seed=dp_rank).cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
-
-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        quantizer_attr = getattr(quantizer, attr).clone()
-
-        # Perform all-reduce operations
-        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-
-        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)
-
-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
-
-    # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
-    # Channel-wise (INT8) only expects same amax across row parallel ranks
-    # Block-wise quantization does not expect same amax across row and column parallel ranks
-    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
-        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-            for quantizer in model.fc1.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-        else:
-            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
-        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-            for quantizer in model.fc2.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-        else:
-            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    # Check act scale
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-        )
-
-
 def auto_quantize_helper(model):
     model, search_state = mtq.auto_quantize(
         model,