
[M4] feat: Add M4 end2end support and qwen3 examples #2011

Merged: yaoyu-33 merged 94 commits into main from m4/5_initialize on Jan 28, 2026

Conversation


@yaoyu-33 (Contributor) commented Jan 21, 2026

Summary

This PR introduces Local Parallel Groups support for Megatron-Bridge, enabling a ProcessGroupCollection to be passed explicitly through functions instead of relying on Megatron-Core's global parallel state (mpu) variables.

Key Features

1. New Config Flag

  • Added use_decentralized_pg flag to DistributedInitConfig
  • When enabled, creates ProcessGroupCollection using HyperCommGrid instead of mpu globals

2. ProcessGroupCollection Propagation

  • Updated training modules (train.py, gpt_step.py, vlm_step.py) to pass pg_collection explicitly
  • Removed direct parallel_state dependencies from training loop

3. Model Provider Updates

  • Added pg_collection parameter to get_model() and provide_distributed_model()
  • Models receive process groups via explicit parameter instead of global state

4. Checkpointing Updates

  • Refactored to use ProcessGroupCollection instead of mpu globals

5. Data Loading

  • Updated setup_data_iterators to accept dp_group parameter for data sharding
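As an illustrative sketch of what the bullet above describes (hypothetical helper names, not the actual Megatron-Bridge code), sharding driven by an explicit dp_group rather than mpu globals could look like:

```python
def shard_indices(num_samples: int, dp_rank: int, dp_world_size: int) -> range:
    # Contiguous slice of sample indices owned by this data-parallel rank.
    per_rank = num_samples // dp_world_size
    start = dp_rank * per_rank
    return range(start, start + per_rank)


def build_data_iterator(dataset, dp_group):
    # dp_group is any object exposing rank()/size(), e.g. a torch.distributed
    # ProcessGroup obtained from pg_collection; no global parallel state is read.
    return (dataset[i] for i in shard_indices(len(dataset), dp_group.rank(), dp_group.size()))
```

Because the group is an explicit argument, the same iterator code works for any model instance, regardless of how its data-parallel group was built.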

6. Optimizer Updates

The optimizer initialization now accepts pg_collection as an init parameter when use_decentralized_pg=True:

| Change | Description |
| --- | --- |
| `setup_optimizer()` | Now accepts an optional `pg_collection` parameter |
| `get_megatron_optimizer()` | Passes `pg_collection` to the underlying optimizer |
| `get_megatron_muon_optimizer()` | Passes `pg_collection` to the Muon optimizer |

This change enables proper optimizer operation with decentralized process groups. When use_decentralized_pg=True, the optimizer uses the provided pg_collection instead of relying on parallel_state (MPU) directly.
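A minimal sketch of that conditional wiring (illustrative names, not the real `setup_optimizer` signature): the extra kwarg is attached only when the flag is set, so the default mpu-based path is untouched.

```python
from types import SimpleNamespace


def optimizer_extra_kwargs(dist_cfg, pg_collection=None) -> dict:
    # Only the decentralized path forwards pg_collection to the optimizer
    # factory; the default (mpu-based) path gets no extra kwargs.
    if not dist_cfg.use_decentralized_pg:
        return {}
    if dist_cfg.use_gloo_process_groups:
        raise ValueError("Gloo process groups are not supported with decentralized PG (NCCL only)")
    if pg_collection is None:
        raise ValueError("use_decentralized_pg=True requires an explicit pg_collection")
    return {"pg_collection": pg_collection}


# Example: the flag controls whether the collection is attached.
cfg = SimpleNamespace(use_decentralized_pg=True, use_gloo_process_groups=False)
extra = optimizer_extra_kwargs(cfg, pg_collection="my_pg_collection")
```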

Note: Gloo process groups are not supported with decentralized process groups (NCCL only). This limitation is enforced via assertion in setup.

Examples

Added examples in examples/recipes/local_parallel_groups/:

| File | Description |
| --- | --- |
| `pretrain_qwen3_simple.py` | Simple: use a recipe with `use_decentralized_pg=True` |
| `pretrain_qwen3_with_local_parallel_groups.py` | Advanced: manual HyperCommGrid and ProcessGroupCollection creation |
| `README.md` | Documentation for both approaches |

Usage

Simple Approach

```python
cfg = qwen3_4b_pretrain_config(...)
cfg.dist.use_decentralized_pg = True
cfg.dist.use_gloo_process_groups = False  # Gloo is not supported (NCCL only)
pretrain(config=cfg, forward_step_func=forward_step)
```

Use Cases

  • Reinforcement Learning: Multiple model instances (policy, value, reference) with different parallelism
  • Multi-Model Pipelines: Complex workflows requiring explicit control over communication
  • Testing/Debugging: Isolated process groups without global state side effects

Testing

  • Unit tests: tests/unit_tests/training/test_local_parallel_groups.py
  • Functional tests: tests/functional_tests/training/test_local_parallel_groups.py
```bash
# Run example
torchrun --nproc_per_node=8 examples/recipes/local_parallel_groups/pretrain_qwen3_simple.py
```

Limitations

  • Gloo process groups are not supported (NCCL only)
  • ModelOpt sharded checkpointing is disabled when using local parallel groups

Summary by CodeRabbit

Release Notes

  • New Features

    • Added decentralized process group support via use_decentralized_pg configuration option, enabling explicit control over distributed training topology using ProcessGroupCollection.
    • Introduced process group collection propagation through model creation, optimizer setup, and training execution.
  • Documentation

    • Added comprehensive README with quick-start guidance for decentralized process group configurations, including simple and advanced manual setups.
  • Examples

    • Added example scripts demonstrating basic and advanced decentralized process group usage with Qwen3 pretraining.
  • Tests

    • Added extensive functional and unit test coverage for decentralized process group feature across various parallelism configurations.


yaoyu-33 and others added 30 commits October 23, 2025 12:24
@yaoyu-33 (Contributor Author)

/ok to test 7161f84

@shifangx (Contributor)

Hi @yaoyu-33, can you add a subsection "x. Optimizer Updates" to the PR readme?
We changed the optimizer to accept pg_collection at init.

@cuichenx (Contributor) left a comment

"Gloo process groups are not supported (NCCL only)"
Should this be added as an assertion in code?

Please check out skyw's comment above as well

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 27, 2026

📝 Walkthrough

Walkthrough

This PR introduces decentralized process group (PG) support to Megatron-Bridge via a ProcessGroupCollection abstraction. It adds configuration options, new HyperCommGrid-based initialization paths, example scripts, and propagates pg_collection throughout training workflows including initialization, model setup, checkpointing, optimization, batching, and training loops.

Changes

Cohort / File(s) Summary
Documentation & Examples for Decentralized PG
examples/recipes/decentralized_pg/README.md, examples/recipes/decentralized_pg/pretrain_qwen3_simple.py, examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py
Adds comprehensive README documenting decentralized process groups, quick-start guidance, and two example scripts (simple and advanced) demonstrating Qwen3 pretraining with manual ProcessGroupCollection setup via HyperCommGrid, including rank layout and limitations documentation.
Configuration & Feature Flag
src/megatron/bridge/training/config.py
Adds use_decentralized_pg: bool = False field to DistributedInitConfig with validation preventing Gloo backend when enabled; enforces NCCL requirement for decentralized PG.
Distributed Initialization & Process Group Creation
src/megatron/bridge/training/initialize.py
Introduces _create_pg_collection() function building comprehensive ProcessGroupCollection via HyperCommGrid with TP, CP, DP, PP, MP, EP groups; reworks _initialize_distributed() to support decentralized path; updates return types of initialize_megatron(), torch_dist_init(), and finish_mpu_init() to return ProcessGroupCollection; refactors random seed setup to use pg_collection ranks instead of MPU.
Training Setup & Orchestration
src/megatron/bridge/training/setup.py
Captures ProcessGroupCollection return from initialize_megatron(); propagates pg_collection to provide_distributed_model() and conditionally to setup_optimizer() based on use_decentralized_pg flag.
Optimizer Setup
src/megatron/bridge/training/optim.py
Adds an optional `pg_collection: ProcessGroupCollection` parameter to optimizer setup.
Model Provider
src/megatron/bridge/models/model_provider.py
Adds an optional `pg_collection: ProcessGroupCollection` parameter to the model provider.
Training Main Loop
src/megatron/bridge/training/train.py
Imports P2PCommunicator; creates P2PCommunicator instance with pg_collection groups; passes pp_size, vp_size, p2p_communicator, and pg_collection to forward-backward function; conditionally sets adjust_tensor_shapes_fn based on use_decentralized_pg; replaces last-rank check with is_pp_last_stage() helper.
Checkpointing
src/megatron/bridge/training/checkpointing.py
Replaces all MPU-based group references with ProcessGroupCollection; adds pg_collection: ProcessGroupCollection parameter to public functions (get_rng_state, generate_state_dict, _generate_model_state_dict, maybe_save_dataloader_state) and internal loading/saving paths; migrates RNG state collection, rank lookups, and group synchronization to use pg_collection groups.
Evaluation
src/megatron/bridge/training/eval.py
Adds P2PCommunicator instantiation and model_config retrieval; passes pp_size, vp_size, p2p_communicator, and pg_collection to forward-backward function calls.
Training Step Functions
src/megatron/bridge/training/gpt_step.py, src/megatron/bridge/training/llava_step.py, src/megatron/bridge/training/vlm_step.py
Updates batch processing to accept and propagate pg_collection; passes cp_group from pg_collection to context-parallel batch slicing; updates llava_step to use pg_collection.pp for pipeline stage detection.
Functional Tests
tests/functional_tests/training/test_decentralized_pg.py
New comprehensive test suite with autouse cleanup fixture and TestDecentralizedPgPretrain class covering pretraining with various parallelism configurations (TP, PP, CP, combined) with use_decentralized_pg enabled/disabled.
Unit Tests — Decentralized PG
tests/unit_tests/training/test_decentralized_pg.py
New extensive unit tests covering DistributedInitConfig flags, _create_pg_collection() with various parallelism configs, _set_random_seed() with pg_collection, initialization branching logic, optimizer setup propagation, checkpointing behavior, and tensor shape adjustments.
Unit Tests — Checkpointing & Infrastructure
tests/unit_tests/training/test_checkpointing.py, tests/unit_tests/training/test_peft_checkpointing.py
Replaces MPU mocks with pg_collection mocks; updates signatures to pass pg_collection to RNG/checkpoint functions; adds dist_checkpointing patches; integrates ProcessGroupCollection from MPU in setup paths.
Unit Tests — PEFT & Step Functions
tests/unit_tests/peft/test_*.py, tests/unit_tests/training/test_vlm_step.py
Integrates ProcessGroupCollection creation from MPU via ProcessGroupCollection.use_mpu_process_groups(); passes pg_collection to _set_random_seed(); adds mock process group objects with rank/size methods for batch/forward-step testing.
Functional Tests — Data & Datasets
tests/functional_tests/data/datasets/test_*.py, tests/functional_tests/data/energon/test_*.py
Creates ProcessGroupCollection from initialized MPU; passes pg_collection to _set_random_seed() for RNG coordination in dataset and data module tests.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Initialize
    participant HyperCommGrid
    participant PGCollection as ProcessGroupCollection
    participant Model
    participant Optimizer
    participant Training

    Client->>Initialize: initialize_megatron(cfg)
    Initialize->>Initialize: Check use_decentralized_pg flag
    alt Decentralized PG Enabled
        Initialize->>HyperCommGrid: Create grid with TP/PP/CP/DP sizes
        HyperCommGrid->>PGCollection: Build collection with all groups
        PGCollection->>Initialize: Return ProcessGroupCollection
    else Standard MPU Path
        Initialize->>Initialize: Use mpu initialization
        Initialize->>PGCollection: Create via use_mpu_process_groups()
    end
    Initialize->>Initialize: Set random seeds with pg_collection ranks
    Initialize-->>Client: Return ProcessGroupCollection

    Client->>Model: provide_distributed_model(cfg, pg_collection=pg_collection)
    Model->>Model: Skip MPU init (pg_collection provided)
    Model->>Model: Create model with pg_collection
    Model-->>Client: Return model

    Client->>Optimizer: setup_optimizer(cfg, pg_collection=pg_collection)
    Optimizer->>Optimizer: Create optimizer with pg_collection
    Optimizer-->>Client: Return optimizer

    Client->>Training: train_step(batch, pg_collection)
    Training->>Training: Create P2PCommunicator(pg_collection.pp)
    Training->>Training: Call forward_backward_func(pp_size, vp_size, p2p_communicator, pg_collection)
    Training->>Training: Use pg_collection for group operations
    Training-->>Client: Step complete
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested reviewers

  • ananthsub
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
- **Test Results For Major Changes** (⚠️ Warning): PR introduces major distributed training changes across 20+ files (~2,000 lines) but lacks documented test execution results, performance metrics, and convergence validation. Resolution: add a test results summary, performance benchmarks, convergence validation, and CI execution status for the various parallelism configurations to the PR description.
- **Title check** (❓ Inconclusive): The title refers to "M4 end2end support and qwen3 examples", but the changeset primarily adds decentralized process group support (via ProcessGroupCollection and HyperCommGrid), with the qwen3 examples being one use case. Resolution: clarify whether "M4" refers to a specific feature/model, or retitle to emphasize "decentralized process groups" / "local parallel groups" as the primary change.
✅ Passed checks (2 passed)
- **Docstring Coverage** (✅ Passed): Docstring coverage is 86.18%, above the required threshold of 80.00%.
- **Description Check** (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.



@coderabbitai (Contributor) bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tests/unit_tests/training/test_vlm_step.py (1)

95-105: Monkeypatched CP slicing helper should accept cp_group kwarg

get_batch_on_this_cp_rank is now called with cp_group=..., but the monkeypatch only accepts a positional arg and will error. Allow cp_group/**kwargs to keep the tests green.

🐛 Proposed fix
```diff
-    monkeypatch.setattr(
-        "megatron.core.utils.get_batch_on_this_cp_rank",
-        lambda x: x,
-        raising=True,
-    )
+    monkeypatch.setattr(
+        "megatron.core.utils.get_batch_on_this_cp_rank",
+        lambda x, **_: x,
+        raising=True,
+    )
-    monkeypatch.setattr("megatron.core.utils.get_batch_on_this_cp_rank", lambda x: x, raising=True)
+    monkeypatch.setattr("megatron.core.utils.get_batch_on_this_cp_rank", lambda x, **_: x, raising=True)
```

Also applies to: 166-168

src/megatron/bridge/training/checkpointing.py (1)

560-579: Avoid serializing ProcessGroup objects in checkpoint metadata.
sharded_sd_metadata is also passed as content_metadata; persisting pg_collection.dp_cp can break serialization. Consider keeping a separate metadata dict for sharding only.

🛠️ Proposed fix
```diff
-    sharded_sd_metadata = _build_sharded_state_dict_metadata(cfg.optimizer.use_distributed_optimizer, ckpt_cfg)
-    sharded_sd_metadata["dp_cp_group"] = pg_collection.dp_cp
+    sharded_sd_metadata = _build_sharded_state_dict_metadata(cfg.optimizer.use_distributed_optimizer, ckpt_cfg)
+    sharding_sd_metadata = dict(sharded_sd_metadata)
+    sharding_sd_metadata["dp_cp_group"] = pg_collection.dp_cp
 ...
-        optim_sd_kwargs=dict(metadata=sharded_sd_metadata),
-        model_sd_kwargs=dict(metadata=sharded_sd_metadata),
+        optim_sd_kwargs=dict(metadata=sharding_sd_metadata),
+        model_sd_kwargs=dict(metadata=sharding_sd_metadata),
```
+        model_sd_kwargs=dict(metadata=sharding_sd_metadata),
src/megatron/bridge/training/llava_step.py (1)

88-134: Fix assertion for last‑stage‑only pipelines.
When is_last is true and is_first is false, tokens/input_ids are intentionally omitted, but the unconditional assertion fails. Make the checks stage‑specific.

🛠️ Proposed fix
```diff
-    assert batch.get("tokens") is not None or batch.get("input_ids") is not None, "tokens or input_ids must be present"
+    if is_first:
+        assert (
+            batch.get("tokens") is not None or batch.get("input_ids") is not None
+        ), "tokens or input_ids must be present on the first PP stage"
+    if is_last:
+        assert (
+            batch.get("labels") is not None and batch.get("loss_mask") is not None
+        ), "labels/loss_mask must be present on the last PP stage"
```
src/megatron/bridge/training/initialize.py (1)

60-163: Docstrings should reflect ProcessGroupCollection return values.
Both initialize_megatron and torch_dist_init can now return a ProcessGroupCollection, but the docs still state “callable or None.”

Docstring update
```diff
-    Returns:
-        An optional callable to finish MPU initialization if lazy_mpu_init is True,
-        otherwise None.
+    Returns:
+        If lazy init is enabled, returns a callable that completes initialization and returns
+        a ProcessGroupCollection; otherwise returns the ProcessGroupCollection directly.
-    Returns:
-        An optional callable to finish MPU initialization if skip_mpu_initialization
-        or lazy_mpu_init is True, otherwise None.
+    Returns:
+        If lazy init is enabled, returns a callable that completes initialization and returns
+        a ProcessGroupCollection; otherwise returns the ProcessGroupCollection directly.
```
🤖 Fix all issues with AI agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py`:
- Around line 555-603: The variable test_data_iterator returned from
setup_data_iterators is unused and will trigger RUF059; rename or prefix it to
_test_data_iterator (or otherwise discard it) where it's assigned to silence the
lint warning — update the assignment that currently reads "train_data_iterator,
valid_data_iterator, test_data_iterator = setup_data_iterators(...)" to use
"train_data_iterator, valid_data_iterator, _test_data_iterator =
setup_data_iterators(...)" so callers (and tools) know the test iterator is
intentionally unused.

In `@examples/recipes/decentralized_pg/README.md`:
- Around line 154-168: The fenced code block under the "HyperCommGrid Explained"
section (the rank layout snippet) is missing a language tag; update the
triple-backtick fence to include a language such as "text" (i.e., change ``` to
```text) so markdownlint MD040 is satisfied and the snippet is treated as plain
text; the block is the one showing "World Size = 8, Shape = [2, 1, 2, 2] means:"
and the subsequent rank layout for HyperCommGrid.

In `@src/megatron/bridge/training/initialize.py`:
- Around line 436-454: The decentralized path always builds
embedding_rank_lists/pos_embedding_rank_lists from the default first/last
PP-stage logic, ignoring the initialize_megatron hooks; modify
_create_pg_collection to accept and use the get_embedding_ranks and
get_position_embedding_ranks callbacks (or their returned lists) when provided,
populating embedding_rank_lists and pos_embedding_rank_lists from those
callbacks instead of always using the current first/last logic, and only fall
back to the existing logic that builds
embedding_rank_lists/pos_embedding_rank_lists from pp_rank_lists when the
callbacks are None; ensure embd_pg and pos_embd_pg are still created via
torch.distributed.new_subgroups_by_enumeration using the final lists so custom
layouts are honored.
- Around line 556-557: The unconditional RuntimeError on device_count==0 should
be restricted to only the workflows that require CUDA; replace the unconditional
check with a guarded check that raises only when no CUDA devices are available
AND the code is initializing decentralized/NCCL groups — e.g. change the if to
something like: if device_count == 0 and (decentralized or is_decentralized) and
(backend == 'nccl' or dist_backend == 'nccl'): raise RuntimeError(...). Update
the check near the parallel-group initialization so it references the existing
variables device_count, decentralized/is_decentralized and backend/dist_backend.

In `@src/megatron/bridge/training/train.py`:
- Around line 253-256: The call to get_forward_backward_func may receive vp_size
as None because config.model.virtual_pipeline_model_parallel_size can be unset;
guard it by defaulting vp_size to 1 before constructing the schedule (e.g.,
compute vp_size = config.model.virtual_pipeline_model_parallel_size or 1) and
then call get_forward_backward_func(pp_size=pg_collection.pp.size(),
vp_size=vp_size) so the schedule construction never gets a None vp_size.
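The guard described in that prompt is a one-line defaulting pattern; a sketch (illustrative helper, not the actual train.py code):

```python
def resolve_vp_size(virtual_pp_size):
    # virtual_pipeline_model_parallel_size may legitimately be unset (None);
    # default it to 1 so schedule construction never receives None.
    return 1 if virtual_pp_size is None else virtual_pp_size
```

Note that the `x or 1` idiom would also coerce a value of 0 to 1; the explicit `is None` check only replaces the unset case.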
🧹 Nitpick comments (2)
src/megatron/bridge/training/checkpointing.py (1)

1026-1028: Silence the unused pg_collection parameter.
_generate_model_state_dict doesn’t use pg_collection, but call sites pass it. Consider marking it intentionally unused to satisfy Ruff.

♻️ Suggested tweak
```diff
 def _generate_model_state_dict(
     model: list[MegatronModule],
     model_sd_kwargs: Optional[dict[str, Any]] = None,
     ckpt_format: str = "torch_dist",
     *,
     pg_collection: ProcessGroupCollection | None = None,
 ) -> dict[str, ShardedStateDict]:
+    _ = pg_collection  # intentionally unused (signature parity)
```

Also applies to: 1093-1095

src/megatron/bridge/training/eval.py (1)

213-225: Consider reusing P2PCommunicator instance.

A second P2PCommunicator is created here for the non-loss data collection path. Since both instances use identical parameters (pp_group=pg_collection.pp, config=model_config), consider reusing the first instance by moving its creation before the main evaluation loop.

However, this is a minor optimization and the current implementation is functionally correct.

Comment on lines +436 to +454
```python
    # Embedding and position-embedding groups
    embd_pg = None
    pos_embd_pg = None
    # Enumerate ranks per PP group
    pp_rank_lists = grid._gen_rank_enum(["pp"])
    # Determine embedding ranks for each pp group
    embedding_rank_lists: list[list[int]] = []
    pos_embedding_rank_lists: list[list[int]] = []
    for ranks in pp_rank_lists:
        if not ranks:
            continue
        # embedding_ranks: first and last pp stage (or only one if pp_size==1)
        embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
        pos_embedding_rank_lists.append([ranks[0]])
    if embedding_rank_lists:
        embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(embedding_rank_lists, backend="nccl")
    if pos_embedding_rank_lists:
        pos_embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(pos_embedding_rank_lists, backend="nccl")
```


⚠️ Potential issue | 🟠 Major

Custom embedding-rank hooks are ignored in the decentralized path.
initialize_megatron accepts get_embedding_ranks / get_position_embedding_ranks, but _create_pg_collection always uses default first/last-stage logic, which changes behavior for custom layouts. This is a functional gap from the MPU path.

Wire embedding rank hooks into _create_pg_collection
```diff
-def _create_pg_collection(
-    model_config: GPTModelProvider | T5ModelProvider, num_distributed_optimizer_instances: int
-) -> ProcessGroupCollection:
+def _create_pg_collection(
+    model_config: GPTModelProvider | T5ModelProvider,
+    num_distributed_optimizer_instances: int,
+    *,
+    get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
+    get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
+) -> ProcessGroupCollection:
@@
-    for ranks in pp_rank_lists:
+    for ranks in pp_rank_lists:
         if not ranks:
             continue
-        # embedding_ranks: first and last pp stage (or only one if pp_size==1)
-        embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
-        pos_embedding_rank_lists.append([ranks[0]])
+        if get_embedding_ranks is not None:
+            embedding_rank_lists.append(get_embedding_ranks(ranks, len(ranks)))
+        else:
+            embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
+        if get_position_embedding_ranks is not None:
+            pos_embedding_rank_lists.append(get_position_embedding_ranks(ranks, len(ranks)))
+        else:
+            pos_embedding_rank_lists.append([ranks[0]])
@@
-        pg_collection = _create_pg_collection(model_config, num_distributed_optimizer_instances)
+        pg_collection = _create_pg_collection(
+            model_config,
+            num_distributed_optimizer_instances,
+            get_embedding_ranks=get_embedding_ranks,
+            get_position_embedding_ranks=get_position_embedding_ranks,
+        )
```
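The hook-aware selection the review asks for reduces to a small pure function; a sketch under that assumption (hypothetical helper, not the PR's actual code):

```python
def select_embedding_rank_lists(pp_rank_lists, get_embedding_ranks=None):
    # Honor a user-supplied hook when present; otherwise fall back to the
    # default first/last-pipeline-stage logic used by _create_pg_collection.
    out = []
    for ranks in pp_rank_lists:
        if not ranks:
            continue
        if get_embedding_ranks is not None:
            out.append(get_embedding_ranks(ranks, len(ranks)))
        elif len(ranks) == 1:
            out.append([ranks[0]])
        else:
            out.append([ranks[0], ranks[-1]])
    return out
```

Keeping the selection in a pure function like this also makes the custom-layout behavior easy to unit test without creating any process groups.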

yaoyu-33 and others added 3 commits January 26, 2026 19:14
@yaoyu-33 (Contributor Author)

/ok to test b6912a7

@yaoyu-33 (Contributor Author)

/ok to test b6a3c14

@yaoyu-33 (Contributor Author)

/ok to test f5495b6

@yaoyu-33 (Contributor Author)

/ok to test 6e11df4

@yaoyu-33 (Contributor Author)

/ok to test d0cb1bf
