@@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
 
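For intuition, the change above extends the existing data-parallel amax sync so that the same reduction is also applied over the context-parallel group, leaving every replica with identical quantization scales. A minimal sketch of that pattern, under the assumption that sync_amax_across_distributed_group reconciles the quantizer's _amax buffer with a MAX all-reduce (sync_amax, dp_group, and cp_group are hypothetical names, not the library implementation):

import torch
import torch.distributed as dist


def sync_amax(amax: torch.Tensor, dp_group=None, cp_group=None) -> torch.Tensor:
    """Reconcile a calibrated amax across DP and CP ranks via MAX all-reduce."""
    # MAX is the safe reduction for amax: the largest magnitude observed on any
    # replica must not be clipped by the shared quantization scale.
    for group in (dp_group, cp_group):
        if group is not None and dist.is_initialized():
            dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax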
@@ -624,13 +625,23 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)
 
+    def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+        # Sync across Data Parallel (DP)
+        if data_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group)
+        # Sync across Context Parallel (CP)
+        if context_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group)
+
     for name, module in model.named_modules():
         if (
             is_quantized_linear(module)
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group)
+
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True
 
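Unlike amax, act_scale is accumulated over num_cache_steps forward passes and normalized by the division shown above, so the natural cross-rank reduction is an average rather than a max: each DP/CP rank calibrates on its own data shard, and averaging approximates the scale obtained from the combined shards (exact when every rank processes the same number of steps). Since dist.ReduceOp.AVG is typically available only with the NCCL backend, a backend-agnostic sketch of the same reduction (average_across_group is a hypothetical helper, not part of this change):

import torch
import torch.distributed as dist


def average_across_group(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Average `tensor` in place across all ranks of the given process group."""
    if dist.is_initialized():
        # SUM followed by dividing by the group size is equivalent to ReduceOp.AVG
        # and also works on backends (e.g. gloo) that do not implement AVG.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        tensor.div_(dist.get_world_size(group=group))
    return tensor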