
Commit 75df0f2

Author: Jennifer Chen (committed)
Commit message: update dp group
Signed-off-by: Jennifer Chen <[email protected]>
1 parent db7155c · commit 75df0f2

2 files changed: +10 additions, -5 deletions

modelopt/torch/quantization/model_calib.py

Lines changed: 3 additions & 4 deletions
@@ -79,22 +79,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return

-    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
+                sync_quantizer_amax_across_dp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group

     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

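The dedicated context-parallel sync line becomes redundant here because the data-parallel group passed to ParallelState (see megatron.py below) is now built with with_context_parallel=True, so it already spans context-parallel ranks. The synchronization itself boils down to an element-wise max-reduction of each quantizer's _amax buffer over the process group. A minimal sketch of that pattern with torch.distributed, assuming an initialized process group (the helper name sync_amax_across_group is hypothetical, not ModelOpt's API):

import torch
import torch.distributed as dist


def sync_amax_across_group(quantizer, group) -> None:
    """Sketch: element-wise max of a quantizer's _amax across all ranks in `group`."""
    amax = getattr(quantizer, "_amax", None)
    if amax is None or group is None:
        return
    # After the all-reduce every rank holds the largest calibrated amax, so all
    # data-parallel (and, with the combined DP x CP group, context-parallel)
    # replicas end up with identical quantization scales.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
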

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 7 additions & 1 deletion
@@ -23,6 +23,7 @@
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
@@ -217,8 +218,13 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]

     def _setup(self):
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except:
+            data_parallel_group = get_data_parallel_group()
         self.parallel_state = ParallelState(
-            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
+            data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
             mcore_parallel.get_context_parallel_group(),
         )
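The bare except in the new _setup silently falls back to the plain data-parallel group whenever the combined DP + CP group cannot be obtained. A narrower variant of that fallback is sketched below; it assumes an uninitialized group surfaces as an AssertionError, as Megatron-Core's parallel_state getters assert, and resolve_data_parallel_group is a hypothetical helper, not part of this commit:

from megatron.core.parallel_state import get_data_parallel_group


def resolve_data_parallel_group():
    """Sketch: prefer the data-parallel group that also spans context-parallel ranks."""
    try:
        # Combined DP x CP group, so amax sync also covers context-parallel replicas.
        return get_data_parallel_group(with_context_parallel=True)
    except AssertionError:
        # Context parallelism not initialized; fall back to the plain DP group.
        return get_data_parallel_group()
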
