@@ -119,21 +119,20 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)


-def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
-    quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
-    dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    for group in groups:
+        if group is not None:
+            quantizer_attr = getattr(quantizer, attr).clone()
+            dist.all_reduce(quantizer_attr, op=op, group=group)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))


 original_awq_lite = model_calib_module.awq_lite


-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
     """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+    return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)


 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
@@ -151,125 +150,101 @@ def forward_loop(model):

     if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
         # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
+        )
         # Lets check the row parallel weight amax; it should be the same across all tp ranks
-        _reduce_quantizer_attr(
-            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
+        _distributed_attr_check(
+            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
         )

     if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
         input_quantizer = model.fc1.input_quantizer
-        _reduce_quantizer_attr(
-            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[tp_group]
         )

     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Check activation scale for AWQ lite
-        _reduce_quantizer_attr(
+        _distributed_attr_check(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
-            group=tp_group,
+            groups=[tp_group],
         )

     dist.destroy_process_group()


 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite):
-    calib_data = model.get_dummy_input().cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
-
-    # Sanity check
-    forward_loop(model)
-
-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    # Weight quantizer amax
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Check act scale
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-
-
-@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Calib data should be same across each DP rank
+def data_tensor_context_parallel_test_helper(
+    model, config, mock_awq_lite, dp_group=None, tp_group=None
+):
+    # Calib data should be different across each DP rank
     dp_rank = dist.get_rank(group=dp_group)
     calib_data = model.get_dummy_input(seed=dp_rank).cuda()

+    if tp_group is not None:
+        # The input to first layer, the column parallel should be the same across all tp ranks
+        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+
     def forward_loop(model):
         model(calib_data)

     model = mtq.quantize(model, config, forward_loop)

-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        quantizer_attr = getattr(quantizer, attr).clone()
-
-        # Perform all-reduce operations
-        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-
-        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)
-
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _distributed_attr_check(
+            model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )

     # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
     # Channel-wise (INT8) only expects same amax across row parallel ranks
     # Block-wise quantization does not expect same amax across row and column parallel ranks
     if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
         if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
             for quantizer in model.fc1.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
         else:
-            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
+            _distributed_attr_check(
+                model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    if config in [
+        mtq.FP8_DEFAULT_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+        mtq.INT8_DEFAULT_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+    ]:
         if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
             for quantizer in model.fc2.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
         else:
-            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+            _distributed_attr_check(
+                model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
+    # It is different across DP/CP ranks since the input is different
+    if tp_group and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        input_quantizer = model.fc1.input_quantizer
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )

     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
+        _distributed_attr_check(
+            model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
         )


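For reference, the invariant the new _distributed_attr_check helper asserts can be reproduced standalone: each rank all-reduces a copy of a calibrated attribute across a process group and checks that the reduced tensor still matches its local value, which only holds if calibration synchronized the attribute across that group. A minimal two-rank sketch, assuming a torchrun launch and a hypothetical pre-synced amax tensor (the file name, backend choice, and tensor values are illustrative assumptions, not part of this change):

# sync_check_sketch.py -- hypothetical standalone repro of the check pattern.
# Launch with: torchrun --nproc_per_node=2 sync_check_sketch.py
import torch
import torch.distributed as dist


def check_synced(tensor: torch.Tensor, op=dist.ReduceOp.MAX, group=None) -> None:
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=op, group=group)
    # If the value was already identical on every rank in the group,
    # the reduction leaves it unchanged and the assertion passes.
    assert torch.allclose(reduced, tensor), (reduced, tensor)


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # gloo keeps the sketch CPU-only
    amax = torch.tensor([2.5])  # stand-in for a calibrated, already-synced amax
    check_synced(amax, op=dist.ReduceOp.MAX)
    dist.destroy_process_group()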