@@ -79,22 +79,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return

-    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
+                sync_quantizer_amax_across_dp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group

     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
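For context, a minimal sketch of what one of these amax syncs amounts to, assuming `torch.distributed` is initialized; the helper name `allreduce_amax` and its signature are illustrative and not the modelopt API. Each rank takes the element-wise MAX of its calibration maximum across the process group, so every rank in that group ends up quantizing with the same range.

```python
# Hypothetical sketch, not the library implementation: synchronize a
# per-tensor amax across one process group via a MAX all-reduce.
import torch
import torch.distributed as dist


def allreduce_amax(amax: torch.Tensor, group=None) -> torch.Tensor:
    """Return amax after taking the element-wise max across all ranks in `group`."""
    if not dist.is_available() or not dist.is_initialized():
        # Single-process run: nothing to synchronize.
        return amax
    synced = amax.detach().clone()
    # After this collective, every rank in `group` holds the same maximum.
    dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
    return synced
```

Calling this once over the (merged) data-parallel group gives the same result as calling it separately over data-parallel and context-parallel groups whose ranks are all contained in it, which is why the separate context-parallel sync above can be dropped.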