[M4] feat: Add M4 end2end support and qwen3 examples#2011
Conversation
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
# Conflicts:
#	src/megatron/bridge/training/eval.py
#	src/megatron/bridge/training/gpt_step.py
Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
# Conflicts:
#	src/megatron/bridge/training/tensor_inspect.py
#	src/megatron/bridge/training/train.py
#	tests/unit_tests/models/test_gpt_full_te_layer_autocast_spec.py
/ok to test 7161f84
Hi, @yaoyu-33, can you add a subsection "x. Optimizer Updates" in the PR readme?
cuichenx
left a comment
Gloo process groups are not supported (NCCL only)
Should this be added as an assertion in the code?
Please check out skyw's comment above as well
📝 Walkthrough

This PR introduces decentralized process group (PG) support to Megatron-Bridge via a ProcessGroupCollection abstraction. It adds configuration options, new HyperCommGrid-based initialization paths, example scripts, and propagates pg_collection throughout training workflows including initialization, model setup, checkpointing, optimization, batching, and training loops.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Initialize
    participant HyperCommGrid
    participant PGCollection as ProcessGroupCollection
    participant Model
    participant Optimizer
    participant Training

    Client->>Initialize: initialize_megatron(cfg)
    Initialize->>Initialize: Check use_decentralized_pg flag
    alt Decentralized PG Enabled
        Initialize->>HyperCommGrid: Create grid with TP/PP/CP/DP sizes
        HyperCommGrid->>PGCollection: Build collection with all groups
        PGCollection->>Initialize: Return ProcessGroupCollection
    else Standard MPU Path
        Initialize->>Initialize: Use mpu initialization
        Initialize->>PGCollection: Create via use_mpu_process_groups()
    end
    Initialize->>Initialize: Set random seeds with pg_collection ranks
    Initialize-->>Client: Return ProcessGroupCollection
    Client->>Model: provide_distributed_model(cfg, pg_collection=pg_collection)
    Model->>Model: Skip MPU init (pg_collection provided)
    Model->>Model: Create model with pg_collection
    Model-->>Client: Return model
    Client->>Optimizer: setup_optimizer(cfg, pg_collection=pg_collection)
    Optimizer->>Optimizer: Create optimizer with pg_collection
    Optimizer-->>Client: Return optimizer
    Client->>Training: train_step(batch, pg_collection)
    Training->>Training: Create P2PCommunicator(pg_collection.pp)
    Training->>Training: Call forward_backward_func(pp_size, vp_size, p2p_communicator, pg_collection)
    Training->>Training: Use pg_collection for group operations
    Training-->>Client: Step complete
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs
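To make the walkthrough concrete, here is a minimal, hypothetical sketch of what a `ProcessGroupCollection`-style container looks like. The field names follow the diagram above; the real class in Megatron-Bridge wraps actual `torch.distributed.ProcessGroup` handles, which this illustration replaces with plain placeholders.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PGCollectionSketch:
    """Hypothetical stand-in for ProcessGroupCollection (illustration only)."""

    tp: Optional[Any] = None     # tensor-parallel group
    pp: Optional[Any] = None     # pipeline-parallel group
    cp: Optional[Any] = None     # context-parallel group
    dp: Optional[Any] = None     # data-parallel group
    dp_cp: Optional[Any] = None  # combined data + context parallel group

    def required(self, name: str) -> Any:
        """Fetch a group, failing loudly if it was never built."""
        group = getattr(self, name)
        if group is None:
            raise ValueError(f"process group {name!r} was not initialized")
        return group


# Training code asks the collection for groups explicitly instead of
# reading mpu globals.
pgs = PGCollectionSketch(tp="tp-group", dp="dp-group")
print(pgs.required("dp"))  # dp-group
```

The point of the abstraction is the explicit hand-off: every consumer (model, optimizer, checkpointing) receives the collection as an argument rather than reaching into global parallel state.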
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tests/unit_tests/training/test_vlm_step.py (1)
95-105: Monkeypatched CP slicing helper should accept a `cp_group` kwarg

`get_batch_on_this_cp_rank` is now called with `cp_group=...`, but the monkeypatch only accepts a positional arg and will error. Allow `cp_group`/`**kwargs` to keep the tests green.

🐛 Proposed fix

```diff
- monkeypatch.setattr(
-     "megatron.core.utils.get_batch_on_this_cp_rank",
-     lambda x: x,
-     raising=True,
- )
+ monkeypatch.setattr(
+     "megatron.core.utils.get_batch_on_this_cp_rank",
+     lambda x, **_: x,
+     raising=True,
+ )
```

```diff
- monkeypatch.setattr("megatron.core.utils.get_batch_on_this_cp_rank", lambda x: x, raising=True)
+ monkeypatch.setattr("megatron.core.utils.get_batch_on_this_cp_rank", lambda x, **_: x, raising=True)
```

Also applies to: 166-168
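The failure mode this fix addresses is easy to reproduce in plain Python: a positional-only stub raises `TypeError` the moment the caller adds a keyword argument, while a `**kwargs`-tolerant stub keeps working (illustration only; not the test file's actual code):

```python
def old_stub(x):
    # Mimics the original monkeypatch: positional argument only.
    return x


def new_stub(x, **_):
    # Proposed replacement: silently swallows future keyword arguments.
    return x


batch = {"tokens": [1, 2, 3]}

assert old_stub(batch) is batch                 # old call style: both stubs work
assert new_stub(batch) is batch
assert new_stub(batch, cp_group="cp") is batch  # new call style: only new_stub survives

try:
    old_stub(batch, cp_group="cp")
except TypeError as exc:
    print(f"old stub fails: {exc}")
```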
src/megatron/bridge/training/checkpointing.py (1)
560-579: Avoid serializing ProcessGroup objects in checkpoint metadata.
`sharded_sd_metadata` is also passed as `content_metadata`; persisting `pg_collection.dp_cp` can break serialization. Consider keeping a separate metadata dict for sharding only.

🛠️ Proposed fix

```diff
 sharded_sd_metadata = _build_sharded_state_dict_metadata(cfg.optimizer.use_distributed_optimizer, ckpt_cfg)
-sharded_sd_metadata["dp_cp_group"] = pg_collection.dp_cp
+sharding_sd_metadata = dict(sharded_sd_metadata)
+sharding_sd_metadata["dp_cp_group"] = pg_collection.dp_cp
 ...
-optim_sd_kwargs=dict(metadata=sharded_sd_metadata),
-model_sd_kwargs=dict(metadata=sharded_sd_metadata),
+optim_sd_kwargs=dict(metadata=sharding_sd_metadata),
+model_sd_kwargs=dict(metadata=sharding_sd_metadata),
```

src/megatron/bridge/training/llava_step.py (1)
88-134: Fix assertion for last‑stage‑only pipelines.
When `is_last` is true and `is_first` is false, tokens/input_ids are intentionally omitted, but the unconditional assertion fails. Make the checks stage-specific.

🛠️ Proposed fix

```diff
-assert batch.get("tokens") is not None or batch.get("input_ids") is not None, "tokens or input_ids must be present"
+if is_first:
+    assert (
+        batch.get("tokens") is not None or batch.get("input_ids") is not None
+    ), "tokens or input_ids must be present on the first PP stage"
+if is_last:
+    assert (
+        batch.get("labels") is not None and batch.get("loss_mask") is not None
+    ), "labels/loss_mask must be present on the last PP stage"
```

src/megatron/bridge/training/initialize.py (1)
60-163: Docstrings should reflect `ProcessGroupCollection` return values.

Both `initialize_megatron` and `torch_dist_init` can now return a `ProcessGroupCollection`, but the docs still state "callable or None."

Docstring update

```diff
 Returns:
-    An optional callable to finish MPU initialization if lazy_mpu_init is True,
-    otherwise None.
+    If lazy init is enabled, returns a callable that completes initialization and returns
+    a ProcessGroupCollection; otherwise returns the ProcessGroupCollection directly.
```

```diff
 Returns:
-    An optional callable to finish MPU initialization if skip_mpu_initialization
-    or lazy_mpu_init is True, otherwise None.
+    If lazy init is enabled, returns a callable that completes initialization and returns
+    a ProcessGroupCollection; otherwise returns the ProcessGroupCollection directly.
```
🤖 Fix all issues with AI agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py`:
- Around line 555-603: The variable test_data_iterator returned from
setup_data_iterators is unused and will trigger RUF059; rename or prefix it to
_test_data_iterator (or otherwise discard it) where it's assigned to silence the
lint warning — update the assignment that currently reads "train_data_iterator,
valid_data_iterator, test_data_iterator = setup_data_iterators(...)" to use
"train_data_iterator, valid_data_iterator, _test_data_iterator =
setup_data_iterators(...)" so callers (and tools) know the test iterator is
intentionally unused.
In `@examples/recipes/decentralized_pg/README.md`:
- Around line 154-168: The fenced code block under the "HyperCommGrid Explained"
section (the rank layout snippet) is missing a language tag; update the
triple-backtick fence to include a language such as "text" (i.e., change ``` to
```text) so markdownlint MD040 is satisfied and the snippet is treated as plain
text; the block is the one showing "World Size = 8, Shape = [2, 1, 2, 2] means:"
and the subsequent rank layout for HyperCommGrid.
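As a torch-free illustration of what that rank-layout snippet encodes: mapping flat ranks into a multi-dimensional grid is plain index arithmetic. The dimension order and first-dimension-fastest convention below are assumptions for illustration; HyperCommGrid defines its own ordering.

```python
import itertools


def enumerate_ranks(shape: list[int]) -> dict[int, tuple[int, ...]]:
    """Map each flat rank to a grid coordinate, first dimension fastest.

    Torch-free sketch; not the HyperCommGrid implementation.
    """
    coords = {}
    axes = [range(s) for s in reversed(shape)]
    for rank, coord in enumerate(itertools.product(*axes)):
        coords[rank] = tuple(reversed(coord))
    return coords


# World Size = 8, Shape = [2, 1, 2, 2] -> 2*1*2*2 = 8 ranks.
layout = enumerate_ranks([2, 1, 2, 2])
print(layout[0])  # (0, 0, 0, 0)
print(layout[7])  # (1, 0, 1, 1)
```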
In `@src/megatron/bridge/training/initialize.py`:
- Around line 436-454: The decentralized path always builds
embedding_rank_lists/pos_embedding_rank_lists from the default first/last
PP-stage logic, ignoring the initialize_megatron hooks; modify
_create_pg_collection to accept and use the get_embedding_ranks and
get_position_embedding_ranks callbacks (or their returned lists) when provided,
populating embedding_rank_lists and pos_embedding_rank_lists from those
callbacks instead of always using the current first/last logic, and only fall
back to the existing logic that builds
embedding_rank_lists/pos_embedding_rank_lists from pp_rank_lists when the
callbacks are None; ensure embd_pg and pos_embd_pg are still created via
torch.distributed.new_subgroups_by_enumeration using the final lists so custom
layouts are honored.
- Around line 556-557: The unconditional RuntimeError on device_count==0 should
be restricted to only the workflows that require CUDA; replace the unconditional
check with a guarded check that raises only when no CUDA devices are available
AND the code is initializing decentralized/NCCL groups — e.g. change the if to
something like: if device_count == 0 and (decentralized or is_decentralized) and
(backend == 'nccl' or dist_backend == 'nccl'): raise RuntimeError(...). Update
the check near the parallel-group initialization so it references the existing
variables device_count, decentralized/is_decentralized and backend/dist_backend.
In `@src/megatron/bridge/training/train.py`:
- Around line 253-256: The call to get_forward_backward_func may receive vp_size
as None because config.model.virtual_pipeline_model_parallel_size can be unset;
guard it by defaulting vp_size to 1 before constructing the schedule (e.g.,
compute vp_size = config.model.virtual_pipeline_model_parallel_size or 1) and
then call get_forward_backward_func(pp_size=pg_collection.pp.size(),
vp_size=vp_size) so the schedule construction never gets a None vp_size.
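The guard that prompt asks for is a one-liner. An explicit `None` check is slightly safer than `configured or 1`, which would also coerce a legitimate `0`; the helper name below is hypothetical, chosen only for the sketch.

```python
from typing import Optional


def resolve_vp_size(configured: Optional[int]) -> int:
    # virtual_pipeline_model_parallel_size may be unset (None); the
    # schedule construction needs a concrete integer, so default to 1.
    return configured if configured is not None else 1


print(resolve_vp_size(None))  # 1
print(resolve_vp_size(4))     # 4
```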
🧹 Nitpick comments (2)
src/megatron/bridge/training/checkpointing.py (1)
1026-1028: Silence the unused pg_collection parameter.
`_generate_model_state_dict` doesn't use `pg_collection`, but call sites pass it. Consider marking it intentionally unused to satisfy Ruff.

♻️ Suggested tweak

```diff
 def _generate_model_state_dict(
     model: list[MegatronModule],
     model_sd_kwargs: Optional[dict[str, Any]] = None,
     ckpt_format: str = "torch_dist",
     *,
     pg_collection: ProcessGroupCollection | None = None,
 ) -> dict[str, ShardedStateDict]:
+    _ = pg_collection  # intentionally unused (signature parity)
```

Also applies to: 1093-1095
src/megatron/bridge/training/eval.py (1)
213-225: Consider reusing the P2PCommunicator instance.

A second `P2PCommunicator` is created here for the non-loss data collection path. Since both instances use identical parameters (`pp_group=pg_collection.pp`, `config=model_config`), consider reusing the first instance by moving its creation before the main evaluation loop. However, this is a minor optimization and the current implementation is functionally correct.
examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py
```python
# Embedding and position-embedding groups
embd_pg = None
pos_embd_pg = None
# Enumerate ranks per PP group
pp_rank_lists = grid._gen_rank_enum(["pp"])
# Determine embedding ranks for each pp group
embedding_rank_lists: list[list[int]] = []
pos_embedding_rank_lists: list[list[int]] = []
for ranks in pp_rank_lists:
    if not ranks:
        continue
    # embedding_ranks: first and last pp stage (or only one if pp_size==1)
    embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
    pos_embedding_rank_lists.append([ranks[0]])
if embedding_rank_lists:
    embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(embedding_rank_lists, backend="nccl")
if pos_embedding_rank_lists:
    pos_embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(pos_embedding_rank_lists, backend="nccl")
```
Custom embedding-rank hooks are ignored in the decentralized path.
initialize_megatron accepts get_embedding_ranks / get_position_embedding_ranks, but _create_pg_collection always uses default first/last-stage logic, which changes behavior for custom layouts. This is a functional gap from the MPU path.
Wire embedding rank hooks into _create_pg_collection
```diff
-def _create_pg_collection(
-    model_config: GPTModelProvider | T5ModelProvider, num_distributed_optimizer_instances: int
-) -> ProcessGroupCollection:
+def _create_pg_collection(
+    model_config: GPTModelProvider | T5ModelProvider,
+    num_distributed_optimizer_instances: int,
+    *,
+    get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
+    get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
+) -> ProcessGroupCollection:
@@
     for ranks in pp_rank_lists:
         if not ranks:
             continue
-        # embedding_ranks: first and last pp stage (or only one if pp_size==1)
-        embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
-        pos_embedding_rank_lists.append([ranks[0]])
+        if get_embedding_ranks is not None:
+            embedding_rank_lists.append(get_embedding_ranks(ranks, len(ranks)))
+        else:
+            embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
+        if get_position_embedding_ranks is not None:
+            pos_embedding_rank_lists.append(get_position_embedding_ranks(ranks, len(ranks)))
+        else:
+            pos_embedding_rank_lists.append([ranks[0]])
```

```diff
-    pg_collection = _create_pg_collection(model_config, num_distributed_optimizer_instances)
+    pg_collection = _create_pg_collection(
+        model_config,
+        num_distributed_optimizer_instances,
+        get_embedding_ranks=get_embedding_ranks,
+        get_position_embedding_ranks=get_position_embedding_ranks,
+    )
```

🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/initialize.py` around lines 436 - 454, The
decentralized path always builds embedding_rank_lists/pos_embedding_rank_lists
from the default first/last PP-stage logic, ignoring the initialize_megatron
hooks; modify _create_pg_collection to accept and use the get_embedding_ranks
and get_position_embedding_ranks callbacks (or their returned lists) when
provided, populating embedding_rank_lists and pos_embedding_rank_lists from
those callbacks instead of always using the current first/last logic, and only
fall back to the existing logic that builds
embedding_rank_lists/pos_embedding_rank_lists from pp_rank_lists when the
callbacks are None; ensure embd_pg and pos_embd_pg are still created via
torch.distributed.new_subgroups_by_enumeration using the final lists so custom
layouts are honored.
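The default first/last-stage selection that the diff keeps as a fallback can be exercised without torch. A standalone sketch of the list-building step (the helper name is mine, not from the source):

```python
def build_embedding_rank_lists(
    pp_rank_lists: list[list[int]],
) -> tuple[list[list[int]], list[list[int]]]:
    """Default logic: embeddings on first+last PP stage, position embeddings on first."""
    embedding_rank_lists: list[list[int]] = []
    pos_embedding_rank_lists: list[list[int]] = []
    for ranks in pp_rank_lists:
        if not ranks:
            continue
        # Single-stage pipelines keep one rank; otherwise first and last stage.
        embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
        pos_embedding_rank_lists.append([ranks[0]])
    return embedding_rank_lists, pos_embedding_rank_lists


# Two PP groups of four stages each (e.g. world size 8, pp=4).
emb, pos = build_embedding_rank_lists([[0, 2, 4, 6], [1, 3, 5, 7]])
print(emb)  # [[0, 6], [1, 7]]
print(pos)  # [[0], [1]]
```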
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
/ok to test b6912a7
…it tests Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test b6a3c14
# Conflicts:
#	src/megatron/bridge/training/checkpointing.py
/ok to test f5495b6
/ok to test 6e11df4
/ok to test d0cb1bf
Summary
This PR introduces Local Parallel Groups support for Megatron-Bridge, enabling the use of a `ProcessGroupCollection` passed through functions instead of relying on Megatron-Core's global parallel state (`mpu`) variables.

Key Features
1. New Config Flag
- Added a `use_decentralized_pg` flag to `DistributedInitConfig`
- When enabled, builds a `ProcessGroupCollection` using `HyperCommGrid` instead of mpu globals

2. ProcessGroupCollection Propagation
- Updated training workflows (`train.py`, `gpt_step.py`, `vlm_step.py`) to pass `pg_collection` explicitly
- Removed `parallel_state` dependencies from the training loop

3. Model Provider Updates
- Added a `pg_collection` parameter to `get_model()` and `provide_distributed_model()`

4. Checkpointing Updates
- Checkpoint save/load now uses the `ProcessGroupCollection` instead of mpu globals

5. Data Loading
- Updated `setup_data_iterators` to accept a `dp_group` parameter for data sharding

6. Optimizer Updates
The optimizer initialization now accepts `pg_collection` as an init parameter when `use_decentralized_pg=True`:
- `setup_optimizer()` accepts a `pg_collection` parameter
- `get_megatron_optimizer()` forwards `pg_collection` to the underlying optimizer
- `get_megatron_muon_optimizer()` forwards `pg_collection` to the Muon optimizer

This change enables proper optimizer operation with decentralized process groups. When `use_decentralized_pg=True`, the optimizer uses the provided `pg_collection` instead of relying on `parallel_state` (MPU) directly.
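A minimal sketch of this plumbing, with hypothetical stand-in names (`*_sketch` suffixes are mine; the real signatures live in `setup_optimizer` / `get_megatron_optimizer`):

```python
from typing import Any, Optional


def get_megatron_optimizer_sketch(cfg: dict, pg_collection: Optional[Any] = None) -> dict:
    # Hypothetical: when a pg_collection is provided, the optimizer keys
    # its reductions off those groups instead of parallel_state globals.
    source = "pg_collection" if pg_collection is not None else "parallel_state"
    return {"cfg": cfg, "group_source": source}


def setup_optimizer_sketch(cfg: dict, pg_collection: Optional[Any] = None) -> dict:
    # Forward pg_collection down, mirroring the propagation described above.
    return get_megatron_optimizer_sketch(cfg, pg_collection=pg_collection)


opt = setup_optimizer_sketch({"lr": 1e-4}, pg_collection=object())
print(opt["group_source"])  # pg_collection
```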
use_decentralized_pg=True, the optimizer uses the providedpg_collectioninstead of relying onparallel_state(MPU) directly.Note: Gloo process groups are not supported with decentralized process groups (NCCL only). This limitation is enforced via assertion in setup.
Examples
Added examples in `examples/recipes/local_parallel_groups/`:
- `pretrain_qwen3_simple.py` — simple approach with `use_decentralized_pg=True`
- `pretrain_qwen3_with_local_parallel_groups.py` — manual `HyperCommGrid` and `ProcessGroupCollection` creation
- `README.md`

Usage
Simple Approach
Use Cases
Testing
- `tests/unit_tests/training/test_local_parallel_groups.py`
- `tests/functional_tests/training/test_local_parallel_groups.py`

```bash
# Run example
torchrun --nproc_per_node=8 examples/recipes/local_parallel_groups/pretrain_qwen3_simple.py
```

Limitations
Summary by CodeRabbit
Release Notes
New Features
- Added the `use_decentralized_pg` configuration option, enabling explicit control over distributed training topology using `ProcessGroupCollection`.

Documentation
Examples
Tests
✏️ Tip: You can customize this high-level summary in your review settings.