@@ -119,21 +119,20 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)
 
 
-def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
-    quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
-    dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    for group in groups:
+        if group is not None:
+            quantizer_attr = getattr(quantizer, attr).clone()
+            dist.all_reduce(quantizer_attr, op=op, group=group)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
 
 
 original_awq_lite = model_calib_module.awq_lite
 
 
-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
     """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+    return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)
 
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
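The replacement helper, _distributed_attr_check, all-reduces a quantizer attribute over each non-None process group and asserts that the reduced tensor still matches the local value, i.e. that calibration already synchronized the statistic across ranks. (As rendered here the assertion sits after the loop, so it checks only the last reduced group and would raise a NameError for an empty groups list; every caller in this diff passes at least one group.) A minimal usage sketch, assuming torch.distributed is already initialized as a 2-rank NCCL job and model is a calibrated test model from these helpers; the group construction is illustrative, not part of the patch:

    import torch.distributed as dist

    # Both ranks form a single tensor-parallel group.
    tp_group = dist.new_group(ranks=[0, 1])

    # Passes only if fc2's input-quantizer amax is identical on both ranks:
    # a MAX all-reduce over already-synchronized values returns them unchanged.
    _distributed_attr_check(
        model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
    )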
@@ -151,125 +150,101 @@ def forward_loop(model):
 
     if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
         # Let's check the amax for the row parallel input quantizer; it should be the same across all tp ranks
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
+        )
         # Let's check the row parallel weight amax; it should be the same across all tp ranks
-        _reduce_quantizer_attr(
-            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
+        _distributed_attr_check(
+            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
         )
 
     if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Let's check the column parallel pre_quant_scale; it should be the same across all tp ranks
         input_quantizer = model.fc1.input_quantizer
-        _reduce_quantizer_attr(
-            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[tp_group]
         )
 
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Check activation scale for AWQ lite
-        _reduce_quantizer_attr(
+        _distributed_attr_check(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
-            group=tp_group,
+            groups=[tp_group],
         )
 
     dist.destroy_process_group()
 
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite):
-    calib_data = model.get_dummy_input().cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
-
-    # Sanity check
-    forward_loop(model)
-
-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    # Weight quantizer amax
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Check act scale
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-
-
-@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Calib data should be same across each DP rank
+def data_tensor_context_parallel_test_helper(
+    model, config, mock_awq_lite, dp_group=None, tp_group=None
+):
+    # Calib data should be different across each DP rank
     dp_rank = dist.get_rank(group=dp_group)
     calib_data = model.get_dummy_input(seed=dp_rank).cuda()
 
+    if tp_group is not None:
+        # The input to the first layer (column parallel) should be the same across all tp ranks
+        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+
     def forward_loop(model):
         model(calib_data)
 
     model = mtq.quantize(model, config, forward_loop)
 
-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        quantizer_attr = getattr(quantizer, attr).clone()
-
-        # Perform all-reduce operations
-        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-
-        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)
-
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _distributed_attr_check(
+            model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
 
     # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
     # Channel-wise (INT8) only expects same amax across row parallel ranks
     # Block-wise quantization does not expect same amax across row and column parallel ranks
     if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
         if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
             for quantizer in model.fc1.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
         else:
-            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
+            _distributed_attr_check(
+                model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    if config in [
+        mtq.FP8_DEFAULT_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+        mtq.INT8_DEFAULT_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+    ]:
         if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
             for quantizer in model.fc2.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
         else:
-            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+            _distributed_attr_check(
+                model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    # Let's check the column parallel pre_quant_scale; it should be the same across all tp ranks
+    # It is different across DP/CP ranks since the input is different
+    if tp_group and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        input_quantizer = model.fc1.input_quantizer
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
 
     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
+        _distributed_attr_check(
+            model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
         )
 
 
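After this change, data_tensor_context_parallel_test_helper takes dp_group and tp_group as optional keyword arguments, seeds the calibration data per DP rank, and averages it across the TP group so the column-parallel inputs agree. A rough sketch of how a spawned worker could drive it on 4 GPUs arranged as 2 data-parallel x 2 tensor-parallel ranks; the group layout, the ToyModel class (anything exposing get_dummy_input(seed=...) plus fc1/fc2 quantized linears), and the spawn wiring are illustrative assumptions, not part of this patch:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import modelopt.torch.quantization as mtq

    def _worker(rank, world_size, model_cls):
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        # 2-way TP x 2-way DP over 4 ranks: {0,1}/{2,3} are TP groups, {0,2}/{1,3} are DP groups.
        # new_group() must be called by every rank for every group, so build them all.
        tp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
        dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
        tp_group = tp_groups[rank // 2]
        dp_group = dp_groups[rank % 2]

        model = model_cls().cuda()
        # The @patch decorator on the helper injects mock_awq_lite automatically.
        data_tensor_context_parallel_test_helper(
            model, mtq.FP8_DEFAULT_CFG, dp_group=dp_group, tp_group=tp_group
        )
        # Process-group teardown is left to the helper / test harness.

A test would launch this with mp.spawn(_worker, args=(4, ToyModel), nprocs=4); the MAX/AVG checks inside the helper then verify that amax, pre_quant_scale, and act_scale were synchronized the way each config requires.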