@@ -79,22 +79,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
+                sync_quantizer_amax_across_dp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
+                    sync_quantizer_amax_across_dp(child, module.parallel_state)
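For context, amax synchronization across a data-parallel group is typically an all-reduce with the MAX reduction, so every rank in the group ends up with the same calibration range. Below is a minimal standalone sketch of that pattern; it assumes torch.distributed is already initialized, and the sync_amax helper is hypothetical for illustration, not ModelOpt's actual sync_amax_across_distributed_group implementation.

# Minimal sketch of DP amax synchronization (hypothetical helper, not ModelOpt's API).
# Assumes torch.distributed is initialized and `group` is the data-parallel process group.
import torch
import torch.distributed as dist


def sync_amax(amax: torch.Tensor, group: dist.ProcessGroup | None = None) -> torch.Tensor:
    """All-reduce the calibration amax with MAX so every DP rank shares the same range."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax

Taking the element-wise maximum (rather than, say, an average) keeps the quantization range independent of how the calibration data happened to be sharded across ranks, which is the same invariance goal the TP-sync comment above states for changing tensor-parallel sizes.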