Skip to content

Commit c3e1d2d

Browse files
wdykas
authored
Fixing PG routing for inference & training separation (#2485)
Co-authored-by: root <root@gpu-h100-0435.cm.cluster> Co-authored-by: root <root@gpu-h100-0012.cm.cluster> Co-authored-by: root <root@gpu-h100-0426.cm.cluster> Co-authored-by: root <root@gpu-h100-0188.cm.cluster> Co-authored-by: root <root@gpu-h100-0013.cm.cluster> Co-authored-by: root <root@gpu-h100-0032.cm.cluster>
1 parent e79d9a8 commit c3e1d2d

File tree

11 files changed

+151
-42
lines changed

11 files changed

+151
-42
lines changed

gpt_builders.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
2222

2323

24-
def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
24+
def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
2525
print_rank_0('building GPT model ...')
2626
if config is None:
2727
if args.yaml_cfg is not None:
@@ -93,6 +93,7 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
9393
rope_scaling=args.use_rope_scaling,
9494
mtp_block_spec=mtp_block_spec,
9595
vp_stage=vp_stage,
96+
pg_collection=pg_collection,
9697
)
9798

9899
return model

mamba_builders.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from megatron.training.arguments import core_transformer_config_from_args
99

1010

11-
def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None):
11+
def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
1212
print_rank_0('building MAMBA model ...')
1313
if config is None:
1414
config = core_transformer_config_from_args(args, TransformerConfig)
@@ -35,6 +35,7 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None):
3535
position_embedding_type=args.position_embedding_type,
3636
rotary_percent=args.rotary_percent,
3737
rotary_base=args.rotary_base,
38+
pg_collection=pg_collection,
3839
)
3940

4041
for l in range(model.decoder.num_layers_per_pipeline_rank):

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from torch import Tensor
2020
from torch.cuda.nvtx import range_pop, range_push
2121

22-
from megatron.core import parallel_state
2322
from megatron.core.inference.contexts.dynamic_context import (
2423
DynamicInferenceContext,
2524
MaxSequenceLengthOverflowError,
@@ -40,8 +39,16 @@
4039
TextGenerationController,
4140
)
4241
from megatron.core.inference.utils import Counter, await_process_event
42+
from megatron.core.process_groups_config import ProcessGroupCollection
4343
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
44-
from megatron.core.utils import get_asyncio_loop, internal_api, trace_async_exceptions
44+
from megatron.core.utils import (
45+
get_asyncio_loop,
46+
get_pg_rank,
47+
get_pg_size,
48+
get_pg_src_rank,
49+
internal_api,
50+
trace_async_exceptions,
51+
)
4552

4653
try:
4754
from tqdm import tqdm
@@ -136,6 +143,7 @@ def __init__(
136143
track_paused_request_events: bool = False,
137144
enable_chunked_prefill: bool = True,
138145
inference_logging_step_interval: int = 0,
146+
pg_collection: Optional[ProcessGroupCollection] = None,
139147
):
140148

141149
assert isinstance(
@@ -159,6 +167,11 @@ def __init__(
159167
controller.inference_wrapped_model.model.config.enable_cuda_graph
160168
)
161169

170+
if pg_collection is not None:
171+
self.pg_collection = pg_collection
172+
else:
173+
self.pg_collection = ProcessGroupCollection.use_mpu_process_groups()
174+
162175
# Initialization options.
163176
self.controller = controller
164177
self.context = context
@@ -378,15 +391,15 @@ async def start_listening_to_data_parallel_coordinator(
378391
self.zmq_sockets = [] # keep track of all sockets created by this engine
379392

380393
# Get world info.
381-
dp_group = parallel_state.get_data_parallel_group()
382-
dp_src = parallel_state.get_data_parallel_src_rank()
383-
dp_size = parallel_state.get_data_parallel_world_size()
384-
dp_rank = parallel_state.get_data_parallel_rank()
394+
dp_group = self.pg_collection.dp
395+
dp_src = get_pg_src_rank(dp_group)
396+
dp_size = get_pg_size(self.pg_collection.dp)
397+
dp_rank = get_pg_rank(self.pg_collection.dp)
385398

386-
mp_group = parallel_state.get_model_parallel_group()
387-
mp_src = parallel_state.get_model_parallel_src_rank()
388-
tp_rank = parallel_state.get_tensor_model_parallel_rank()
389-
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
399+
mp_group = self.pg_collection.mp
400+
mp_src = get_pg_src_rank(mp_group)
401+
tp_rank = get_pg_rank(self.pg_collection.tp)
402+
pp_rank = get_pg_rank(self.pg_collection.pp)
390403

391404
self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
392405
self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
@@ -400,7 +413,7 @@ async def start_listening_to_data_parallel_coordinator(
400413
args=(
401414
coordinator_ready_event,
402415
inference_coordinator_port,
403-
parallel_state.get_data_parallel_world_size(),
416+
get_pg_size(self.pg_collection.dp),
404417
),
405418
)
406419
self.inference_coordinator_process.start()

megatron/core/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,23 @@ def get_pg_rank(group=None):
541541
return group.rank()
542542

543543

544+
def get_pg_src_rank(group=None):
545+
"""Calculate the global rank corresponding to the first local rank
546+
in the given process group.
547+
548+
Args:
549+
group: Process group to query. If None or distributed is not initialized,
550+
returns 0.
551+
552+
Returns:
553+
int: The first (source) global rank in the group.
554+
"""
555+
if not torch.distributed.is_initialized() or group is None:
556+
return 0
557+
ranks = torch.distributed.get_process_group_ranks(group)
558+
return ranks[0]
559+
560+
544561
def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
545562
"""Get an attribute from a wrapped model.
546563
If return_model_obj is true, return the object that has the 'attr' attribute;

megatron/rl/inference/megatron.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydantic import PrivateAttr
99

1010
from megatron.core import parallel_state
11+
from megatron.core.utils import get_attr_wrapped_model
1112
from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
1213
from megatron.core.inference.engines.abstract_engine import AbstractEngine
1314
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
@@ -26,7 +27,11 @@
2627
from megatron.core.models.gpt.gpt_model import GPTModel
2728
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
2829
from megatron.core.transformer.module import MegatronModule
29-
from megatron.core.utils import get_mamba_inference_state_config_from_model, log_single_rank
30+
from megatron.core.pipeline_parallel.utils import (
31+
is_pp_first_stage,
32+
is_pp_last_stage,
33+
)
34+
from megatron.core.utils import get_mamba_inference_state_config_from_model, log_single_rank, get_pg_size
3035
from megatron.training import get_wandb_writer
3136
from megatron.training.global_vars import get_args, get_tokenizer
3237

@@ -109,6 +114,16 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
109114

110115
mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
111116

117+
# DynamicInferenceContext must use the inference model's TP size, not the
118+
# training TP size from global args. The inference model may have a custom
119+
# ProcessGroupCollection with a different TP size.
120+
pg_collection = get_attr_wrapped_model(model, "pg_collection")
121+
tp_group = getattr(pg_collection, 'tp', None) if pg_collection is not None else None
122+
if tp_group is not None:
123+
inference_tp_size = get_pg_size(tp_group)
124+
else:
125+
inference_tp_size = args.tensor_model_parallel_size
126+
112127
# Inference context.
113128
inference_context = DynamicInferenceContext(
114129
params_dtype=args.params_dtype,
@@ -126,7 +141,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
126141
block_size_tokens=args.inference_dynamic_batching_block_size,
127142
buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
128143
max_tokens=args.inference_dynamic_batching_max_tokens,
129-
tensor_model_parallel_size=args.tensor_model_parallel_size,
144+
tensor_model_parallel_size=inference_tp_size,
130145
materialize_only_last_token_logits=True,
131146
mamba_inference_state_config=mamba_inference_state_config,
132147
cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents,
@@ -143,7 +158,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
143158
inference_wrapped_model = GPTInferenceWrapper(model, args, inference_context)
144159

145160
inference_wrapped_model.model_is_pipeline_parallel = not (
146-
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
161+
is_pp_first_stage(pg_collection.pp) and is_pp_last_stage(pg_collection.pp)
147162
)
148163

149164
text_generation_controller = SimpleTextGenerationController(
@@ -156,6 +171,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
156171
enable_cuda_graph=enable_cuda_graph,
157172
random_seed=args.seed,
158173
inference_logging_step_interval=inference_logging_step_interval,
174+
pg_collection=pg_collection,
159175
)
160176

161177

megatron/training/training.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,18 @@
5252
from megatron.core.utils import (
5353
check_param_hashes_across_dp_replicas,
5454
get_model_config,
55+
get_pg_size,
56+
get_pg_rank,
5557
StragglerDetector,
5658
)
5759
from megatron.core.fp8_utils import correct_amax_history_if_needed
60+
from megatron.core.process_groups_config import ProcessGroupCollection
61+
from megatron.core.pipeline_parallel.utils import (
62+
is_pp_first_stage,
63+
is_pp_last_stage,
64+
is_vp_first_stage,
65+
is_vp_last_stage,
66+
)
5867
from megatron.training.checkpointing import load_checkpoint
5968
from megatron.training.checkpointing import save_checkpoint
6069
from megatron.training.checkpointing import checkpoint_exists
@@ -873,10 +882,12 @@ def update_train_iters(args):
873882
print_rank_0(f'setting training iterations to {args.train_iters}')
874883

875884

876-
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
885+
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True, config=None, pg_collection=None):
877886
"""Build the model."""
878887
args = get_args()
879888
args.model_type = model_type
889+
if pg_collection is None:
890+
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
880891

881892
if has_nvidia_modelopt:
882893
from megatron.post_training.checkpointing import has_modelopt_state
@@ -893,23 +904,38 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
893904
# Build model.
894905
def build_model():
895906
if (
896-
mpu.get_pipeline_model_parallel_world_size() > 1
907+
get_pg_size(pg_collection.pp) > 1
897908
and args.virtual_pipeline_model_parallel_size is not None
898909
):
899910
model = []
900-
for i in range(args.virtual_pipeline_model_parallel_size):
911+
vp_size = args.virtual_pipeline_model_parallel_size
912+
for i in range(vp_size):
901913
# Set pre_process and post_process only after virtual rank is set.
902-
pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
903-
post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
914+
pre_process = is_pp_first_stage(pg_collection.pp) and is_vp_first_stage(
915+
vp_stage=i, vp_size=vp_size
916+
)
917+
post_process = is_pp_last_stage(pg_collection.pp) and is_vp_last_stage(
918+
vp_stage=i, vp_size=vp_size
919+
)
904920
this_model = model_provider_func(
905-
pre_process=pre_process, post_process=post_process, vp_stage=i)
921+
pre_process=pre_process,
922+
post_process=post_process,
923+
vp_stage=i,
924+
config=config,
925+
pg_collection=pg_collection,
926+
)
906927
this_model.model_type = model_type
907928
this_model.vp_stage = i
908929
model.append(this_model)
909930
else:
910-
pre_process = mpu.is_pipeline_first_stage()
911-
post_process = mpu.is_pipeline_last_stage()
912-
model = model_provider_func(pre_process=pre_process, post_process=post_process)
931+
pre_process = is_pp_first_stage(pg_collection.pp)
932+
post_process = is_pp_last_stage(pg_collection.pp)
933+
model = model_provider_func(
934+
pre_process=pre_process,
935+
post_process=post_process,
936+
config=config,
937+
pg_collection=pg_collection,
938+
)
913939
model.model_type = model_type
914940
return model
915941

@@ -934,12 +960,12 @@ def build_model():
934960
num_parameters = sum(
935961
[sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
936962
)
937-
if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0:
963+
if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
938964
print(
939965
' > number of parameters on (tensor, pipeline) '
940966
'model parallel rank ({}, {}): {}'.format(
941-
mpu.get_tensor_model_parallel_rank(),
942-
mpu.get_pipeline_model_parallel_rank(),
967+
get_pg_rank(pg_collection.tp),
968+
get_pg_rank(pg_collection.pp),
943969
num_parameters,
944970
),
945971
flush=True,

model_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
def model_provider(
25-
model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None
25+
model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None, config=None, pg_collection=None,
2626
) -> Union[GPTModel, megatron.legacy.model.GPTModel, MambaModel]:
2727
"""Builds the model.
2828
@@ -64,7 +64,7 @@ def oom_observer(device, alloc, device_alloc, device_free):
6464
# [ModelOpt]: Use custom builder + spec when modelopt is enabled
6565
model_builder = modelopt_gpt_mamba_builder
6666

67-
return model_builder(args, pre_process, post_process, vp_stage)
67+
return model_builder(args, pre_process, post_process, vp_stage, config=config, pg_collection=pg_collection)
6868

6969

7070
def count_parameters_in_layer(model, layer_name):

tests/unit_tests/dist_checkpointing/test_optimizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ def initialize_real_model(
276276
virtual_pipeline_model_parallel_size=None,
277277
**config_kwargs,
278278
):
279+
# These kwargs are passed through training.get_model for model construction,
280+
# but are not part of TransformerConfig; strip them before building config.
281+
config_kwargs.pop("pg_collection", None)
282+
config_kwargs.pop("config", None)
283+
279284
torch.manual_seed(seed)
280285
model_parallel_cuda_manual_seed(seed)
281286

tests/unit_tests/dist_checkpointing/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
12
from functools import partial
23
from typing import Any, Callable, Tuple, Union
34
from unittest import mock
@@ -24,6 +25,11 @@
2425
def initialize_gpt_model(
2526
pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs
2627
):
28+
# These kwargs are passed through training.get_model for model construction,
29+
# but are not part of TransformerConfig; strip them before building config.
30+
config_kwargs.pop("pg_collection", None)
31+
config_kwargs.pop("config", None)
32+
2733
torch.manual_seed(seed)
2834
model_parallel_cuda_manual_seed(seed)
2935

@@ -61,6 +67,11 @@ def initialize_moe_model(
6167
use_grouped_mlp=False,
6268
**config_kwargs,
6369
):
70+
# These kwargs are passed through training.get_model for model construction,
71+
# but are not part of TransformerConfig; strip them before building config.
72+
config_kwargs.pop("pg_collection", None)
73+
config_kwargs.pop("config", None)
74+
6475
torch.manual_seed(seed)
6576
model_parallel_cuda_manual_seed(seed)
6677
expert_num = 8

tests/unit_tests/transformer/test_multi_latent_attention.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,9 @@ def test_parallel_multi_latent_attention_correctness(
10821082
hidden_size = 128
10831083

10841084
# Model initialization function
1085-
def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=None):
1085+
def initialize_gpt_model(
1086+
pre_process=True, post_process=True, vp_stage=None, pg_collection=None, config=None
1087+
):
10861088
layer_spec = get_gpt_layer_with_transformer_engine_spec(multi_latent_attention=True)
10871089
gpt_model = GPTModel(
10881090
config=config,
@@ -1141,9 +1143,7 @@ def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=N
11411143
init_basic_mock_args(mock_args, 1, 1, bf16=True)
11421144
mock_args.context_parallel_size = 1
11431145
mock_args.sequence_parallel = 1
1144-
gpt_model = unwrap_model(
1145-
get_model(partial(initialize_gpt_model, config=transformer_config))
1146-
)
1146+
gpt_model = unwrap_model(get_model(initialize_gpt_model, config=transformer_config))
11471147

11481148
# Initialize args and save checkpoint
11491149
init_checkpointing_mock_args(mock_args, ckpt_dir, False)
@@ -1178,9 +1178,7 @@ def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=N
11781178
init_basic_mock_args(mock_args, tp, 1, bf16=True)
11791179
mock_args.context_parallel_size = cp
11801180
mock_args.sequence_parallel = sp
1181-
gpt_model = unwrap_model(
1182-
get_model(partial(initialize_gpt_model, config=transformer_config))
1183-
)
1181+
gpt_model = unwrap_model(get_model(initialize_gpt_model, config=transformer_config))
11841182
with mock.patch('megatron.training.checkpointing.check_checkpoint_args'):
11851183
with mock.patch('megatron.training.checkpointing.update_num_microbatches'):
11861184
load_checkpoint(gpt_model, None, None)

0 commit comments

Comments (0)