@@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
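The hunk above renames the helper and makes every quantizer sync its amax over the context parallel group in addition to the data parallel group. As a rough illustration only (not the library's implementation of sync_amax_across_distributed_group), syncing a calibration amax across a process group boils down to an element-wise max-reduction; a minimal sketch with torch.distributed, where `group` is assumed to be an already-initialized process group handle:

    import torch
    import torch.distributed as dist

    def sync_amax(amax: torch.Tensor, group: dist.ProcessGroup) -> None:
        # Every rank contributes its local amax; after the all-reduce all ranks
        # hold the element-wise maximum, so they quantize with identical ranges.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)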
@@ -624,13 +625,23 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)
 
+    def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+        # Sync across Data Parallel (DP)
+        if data_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group)
+        # Sync across Context Parallel (CP)
+        if context_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group)
+
     for name, module in model.named_modules():
         if (
             is_quantized_linear(module)
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group)
+
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True
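Because the DP and CP groups partition the ranks into a uniform grid, averaging act_scale first over DP and then over CP is equivalent to a single average over all DP x CP ranks, which is why the two sequential AVG all-reduces in the hunk above are sufficient. A small single-process sketch of that equivalence, simulating per-rank values with one tensor axis per group (purely illustrative, not code from this change):

    import torch

    dp, cp, channels = 4, 2, 16
    act_scale = torch.rand(dp, cp, channels)                   # one act_scale per (dp, cp) rank
    after_dp = act_scale.mean(dim=0, keepdim=True)             # effect of the DP all-reduce AVG
    after_cp = after_dp.expand(dp, cp, channels).mean(dim=1)   # effect of the CP all-reduce AVG
    assert torch.allclose(after_cp, act_scale.mean(dim=(0, 1)))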