Commit e764e79

sync awq act scale
Signed-off-by: Jennifer Chen <[email protected]>
1 parent faff217 commit e764e79

modelopt/torch/quantization/model_calib.py

Lines changed: 14 additions & 3 deletions
@@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return

-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group

     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

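Note on the hunk above: the amax synchronization now covers both the data-parallel and the context-parallel group. As a minimal standalone sketch (not the modelopt implementation of sync_amax_across_distributed_group; the helper name and the element-wise MAX reduction are illustrative assumptions), one such sync with torch.distributed could look like this:

import torch
import torch.distributed as dist

def sync_amax(amax: torch.Tensor, group) -> torch.Tensor:
    # Hypothetical helper: reduce the local amax with an element-wise MAX
    # across every rank in `group`, so all ranks share the same calibration range.
    if dist.is_available() and dist.is_initialized() and group is not None:
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax

# Mirroring the hunk, the same call would be issued once per group
# (the group handles here are hypothetical stand-ins):
# sync_amax(quantizer._amax, data_parallel_group)
# sync_amax(quantizer._amax, context_parallel_group)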
@@ -624,13 +625,23 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)

+    def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
+        # Sync across Data Parallel (DP)
+        if data_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group)
+        # Sync across Context Parallel (CP)
+        if context_parallel_group.is_initialized():
+            dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group)
+
     for name, module in model.named_modules():
         if (
             is_quantized_linear(module)
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
+            sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group)
+
             # Hack: MoEs forward all tokens through all experts if _if_calib is True
             module._if_calib = True

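For context on the hunk above: awq_lite.act_scale is accumulated as a running sum over num_cache_steps calibration steps, turned into a local mean, and then averaged across the data-parallel and context-parallel ranks. A standalone sketch of that order of operations (the tensor and group handles below are stand-ins for module.awq_lite.act_scale and the groups on module.parallel_state):

import torch
import torch.distributed as dist

def finalize_act_scale(act_scale: torch.Tensor, num_cache_steps: int, groups) -> torch.Tensor:
    # Turn the running sum collected during calibration into a local mean ...
    act_scale = act_scale / num_cache_steps
    # ... then average that mean over each available process group.
    # Note: ReduceOp.AVG requires a backend that supports it (e.g. NCCL).
    for group in groups:
        if group is not None:
            dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=group)
    return act_scale

# Example wiring (hypothetical handles):
# finalize_act_scale(act_scale, num_cache_steps, [dp_group, cp_group])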