
fix: enhancing non-colocated refit performance by having inclusive comm group#1264

Merged
terrykong merged 4 commits into NVIDIA-NeMo:main from youngeunkwon0405:async-refit
Oct 8, 2025

Conversation

@youngeunkwon0405
Contributor

@youngeunkwon0405 youngeunkwon0405 commented Oct 3, 2025

What does this PR do?

Includes every training rank in the NCCL communication group used for refit, so weights are broadcast from train rank 0 to all ranks instead of funneling through a single rank.

image

Speedup

QWEN3 32B run

2x speedup in weight-sync time for a QWEN3 32B run.

image [Wandb](https://wandb.ai/nvidia/async-grpo-refit?nw=nwuseryoungeunk&panelDisplayName=timing%2Ftrain%2Fweight_sync&panelSectionName=timing)

Nsys report

Current

The broadcast takes much longer than the all-gather.
image

This PR

With this PR, the broadcast and all-gather latencies are comparable.
image

Issues

Refit speedup for non-colocated (NCCL)

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
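A hedged sketch of the updated call pattern. The stub below mirrors the signature change described in the walkthrough (a keyword-only `train_world_size`); the real `init_collective` lives in nemo_rl, and all names here are illustrative, not verbatim.

```python
# Hypothetical stand-in for Policy.init_collective; in this PR the real method
# gains a keyword-only train_world_size parameter that callers must pass.
def init_collective(ip: str, port: int, *, world_size: int, train_world_size: int) -> dict:
    # Return a handle describing the communicator setup (illustration only).
    return {
        "ip": ip,
        "port": port,
        "world_size": world_size,
        "train_world_size": train_world_size,
    }

train_world_size = 4       # e.g. train_cluster.world_size()
inference_world_size = 2   # number of vLLM inference workers
world_size = train_world_size + inference_world_size

handle = init_collective(
    "127.0.0.1",
    29500,
    world_size=world_size,
    train_world_size=train_world_size,
)
print(handle["world_size"])  # → 6
```

The key point is that `world_size` now counts both groups, while `train_world_size` lets inference workers offset their ranks past the training ranks.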

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.


Summary by CodeRabbit

  • New Features
    • Explicit train_world_size support across training and inference collectives for more reliable multi-node/non-colocated runs.
    • All training ranks now participate in NCCL; weights broadcast uniformly to all ranks.
  • Bug Fixes
    • Correct world size calculation and rank assignment for non-colocated inference and VLLM workers.
  • Refactor
    • init_collective signatures updated to require train_world_size across policies and VLLM components; callers must pass this parameter.

@youngeunkwon0405 youngeunkwon0405 self-assigned this Oct 3, 2025
@youngeunkwon0405 youngeunkwon0405 requested review from a team as code owners October 3, 2025 00:46
@youngeunkwon0405 youngeunkwon0405 marked this pull request as draft October 3, 2025 00:46
@github-actions

github-actions bot commented Oct 3, 2025

⚠️ File Consistency Check

Check based on commit: e06d82e (PR #1264 from async-refit)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/dtensor_policy_worker.py was modified in this PR, but nemo_rl/models/policy/dtensor_policy_worker_v2.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/dtensor_policy_worker.py should also be applied to nemo_rl/models/policy/dtensor_policy_worker_v2.py
  • Update nemo_rl/models/policy/dtensor_policy_worker_v2.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/dtensor_policy_worker.py
  • Not modified: nemo_rl/models/policy/dtensor_policy_worker_v2.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@youngeunkwon0405 youngeunkwon0405 changed the title fix: enhancing non-colocated refit performance by including all training ranks to the communication group fix: Enhancing non-colocated refit performance by including all training ranks to the communication group Oct 3, 2025
@coderabbitai
Contributor

coderabbitai bot commented Oct 3, 2025

📝 Walkthrough

Walkthrough

Updates propagate explicit train world size through distributed initialization. World size is now computed as train_cluster.world_size() plus inference workers. init_collective signatures gain a train_world_size parameter across policy, VLLM generation stack, and workers. Rank calculation and communicator setup are adjusted so all train ranks join, removing prior rank-0 special-casing.

Changes

Cohort / File(s) Summary of changes
Algorithms: world size calc and init wiring
nemo_rl/algorithms/distillation.py, nemo_rl/algorithms/grpo.py
Compute train_world_size from train_cluster.world_size(); compute total world_size as train + inference; pass train_world_size into init_collective for training and inference components.
VLLM generation stack: propagate train_world_size
nemo_rl/models/generation/vllm/vllm_generation.py, nemo_rl/models/generation/vllm/vllm_worker.py, nemo_rl/models/generation/vllm/vllm_worker_async.py, nemo_rl/models/generation/vllm/vllm_backend.py
Add train_world_size parameter to init_collective across generation classes; adjust rank computation to offset by train_world_size; forward parameter through collective_rpc.
Policy interface and LM policy
nemo_rl/models/policy/interfaces.py, nemo_rl/models/policy/lm_policy.py
Update ColocatablePolicyInterface and Policy.init_collective signatures to include keyword-only train_world_size; forward parameter to worker group invocation.
DTensor policy worker
nemo_rl/models/policy/dtensor_policy_worker.py
Add keyword-only train_world_size to init_collective; have all train ranks join communicator with rank=self.rank; unify broadcast path across ranks (remove rank-0-only branches).
Megatron policy worker
nemo_rl/models/policy/megatron_policy_worker.py
Add keyword-only train_world_size to init_collective; all train ranks join PyNcclCommunicator with rank=self.rank; unify weight broadcast from train rank 0 to all ranks; update comments/docstring.
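The rank layout described in the table above reduces to simple arithmetic. A runnable sketch under assumed sizes (`rank_prefix` and `local_rank` are names taken from the walkthrough, not the actual code):

```python
def inference_rank(train_world_size: int, rank_prefix: int, local_rank: int) -> int:
    # Inference (vLLM) ranks are offset past every training rank, which keeps
    # ranks 0..train_world_size-1 free for the training workers.
    return train_world_size + rank_prefix + local_rank

train_world_size = 8
# Assume two vLLM engines with tensor_parallel_size=2 -> rank prefixes 0 and 2.
inference_ranks = [
    inference_rank(train_world_size, prefix, local)
    for prefix in (0, 2)
    for local in (0, 1)
]
print(inference_ranks)  # → [8, 9, 10, 11]
```

Because training ranks keep their torch rank unchanged, the offset guarantees the two groups never collide inside the shared communicator.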

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer[Training Ranks (0..train_world_size-1)]
  participant Orchestrator[Init Orchestrator]
  participant Inference[Inference Workers (VLLM)]

  Note over Orchestrator: Compute train_world_size and inference_world_size<br/>world_size = train + inference

  Orchestrator->>Trainer: init_collective(ip, port, world_size, train_world_size)
  loop for each training rank
    Trainer->>Trainer: Join NCCL (rank = self.rank)
  end

  Orchestrator->>Inference: init_collective(ip, port, world_size, train_world_size)
  loop for each inference worker
    Inference->>Inference: rank = train_world_size + rank_prefix + local_rank
    Inference->>Inference: Join NCCL
  end

  Note over Trainer,Inference: Single communicator across training + inference

  Trainer->>Inference: Broadcast weights (src=train rank 0)
  Inference-->>Trainer: Acks
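Putting the diagram's numbers together: training and inference ranks should tile the communicator's rank space exactly once, with weight broadcasts rooted at train rank 0. A small self-check under assumed sizes (two engines with tp=2 are an assumption for illustration):

```python
train_world_size = 4
inference_world_size = 4
world_size = train_world_size + inference_world_size

train_ranks = list(range(train_world_size))           # all train ranks join
inference_ranks = [train_world_size + prefix + local  # offset past training
                   for prefix in (0, 2)               # two engines, tp=2 (assumed)
                   for local in (0, 1)]

broadcast_src = 0  # weights are broadcast from train rank 0 to everyone

assert sorted(train_ranks + inference_ranks) == list(range(world_size))
assert broadcast_src in train_ranks
print("communicator covers ranks", train_ranks + inference_ranks)
```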

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

✅ Passed checks (4 passed)
Check name Status Explanation
Docstring Coverage ✅ Passed Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%.
Test Results For Major Changes ✅ Passed
Title Check ✅ Passed The title succinctly captures the primary change of including all training ranks in the communication group to improve non-colocated refit performance.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)

823-841: Leverage train_world_size to validate communicator ranks.

train_world_size is never used, so a mis-sized training group (or an unexpected torch rank) would silently collide with inference ranks when the NCCL communicator is created. Guard the rank against train_world_size before calling StatelessProcessGroup.create so we fail fast on bad deployments.

```diff
     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
     from vllm.distributed.utils import StatelessProcessGroup

-        # All training ranks [0..train_world_size-1] join the communicator with their rank
+        if not 0 <= self.rank < train_world_size:
+            raise ValueError(
+                f"Megatron rank {self.rank} must be < train_world_size ({train_world_size})"
+            )
+
+        # All training ranks [0..train_world_size-1] join the communicator with their rank
         pg = StatelessProcessGroup.create(
             host=ip, port=port, rank=self.rank, world_size=world_size
         )
```
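The guard suggested above boils down to a range check that fails fast before the communicator is created. A runnable sketch (`validate_train_rank` is an illustrative name, not from the codebase):

```python
def validate_train_rank(rank: int, train_world_size: int) -> None:
    # Reject ranks that would collide with the inference rank range
    # [train_world_size..world_size-1] before NCCL setup begins.
    if not 0 <= rank < train_world_size:
        raise ValueError(
            f"rank {rank} must be in [0, train_world_size={train_world_size})"
        )

validate_train_rank(3, train_world_size=4)      # OK: a valid training rank
try:
    validate_train_rank(4, train_world_size=4)  # would collide with inference rank 4
except ValueError as e:
    print(e)
```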
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 557b7ec and e06d82e.

📒 Files selected for processing (10)
  • nemo_rl/algorithms/distillation.py (1 hunks)
  • nemo_rl/algorithms/grpo.py (1 hunks)
  • nemo_rl/models/generation/vllm/vllm_backend.py (1 hunks)
  • nemo_rl/models/generation/vllm/vllm_generation.py (2 hunks)
  • nemo_rl/models/generation/vllm/vllm_worker.py (2 hunks)
  • nemo_rl/models/generation/vllm/vllm_worker_async.py (2 hunks)
  • nemo_rl/models/policy/dtensor_policy_worker.py (2 hunks)
  • nemo_rl/models/policy/interfaces.py (1 hunks)
  • nemo_rl/models/policy/lm_policy.py (1 hunks)
  • nemo_rl/models/policy/megatron_policy_worker.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/algorithms/distillation.py
  • nemo_rl/models/generation/vllm/vllm_worker.py
  • nemo_rl/models/generation/vllm/vllm_worker_async.py
  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/policy/lm_policy.py
  • nemo_rl/models/policy/interfaces.py
  • nemo_rl/models/generation/vllm/vllm_backend.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/algorithms/grpo.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/distillation.py
  • nemo_rl/models/generation/vllm/vllm_worker.py
  • nemo_rl/models/generation/vllm/vllm_worker_async.py
  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/policy/lm_policy.py
  • nemo_rl/models/policy/interfaces.py
  • nemo_rl/models/generation/vllm/vllm_backend.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/algorithms/grpo.py
🧬 Code graph analysis (7)
nemo_rl/algorithms/distillation.py (8)
nemo_rl/distributed/virtual_cluster.py (1)
  • world_size (357-358)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • init_collective (34-55)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • init_collective (370-407)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • init_collective (479-496)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
  • init_collective (504-518)
nemo_rl/models/policy/interfaces.py (1)
  • init_collective (143-146)
nemo_rl/models/policy/lm_policy.py (1)
  • init_collective (236-248)
nemo_rl/models/policy/megatron_policy_worker.py (1)
  • init_collective (823-840)
nemo_rl/models/policy/dtensor_policy_worker.py (7)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • init_collective (34-55)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • init_collective (370-407)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • init_collective (479-496)
nemo_rl/models/policy/interfaces.py (1)
  • init_collective (143-146)
nemo_rl/models/policy/lm_policy.py (1)
  • init_collective (236-248)
nemo_rl/models/policy/megatron_policy_worker.py (1)
  • init_collective (823-840)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)
  • init_collective (462-472)
nemo_rl/models/policy/lm_policy.py (2)
nemo_rl/distributed/virtual_cluster.py (1)
  • world_size (357-358)
nemo_rl/distributed/worker_groups.py (1)
  • run_all_workers_single_data (728-772)
nemo_rl/models/policy/interfaces.py (1)
nemo_rl/distributed/virtual_cluster.py (1)
  • world_size (357-358)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
nemo_rl/distributed/virtual_cluster.py (1)
  • world_size (357-358)
nemo_rl/models/policy/megatron_policy_worker.py (5)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • init_collective (34-55)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • init_collective (370-407)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
  • init_collective (504-518)
nemo_rl/models/policy/interfaces.py (1)
  • init_collective (143-146)
nemo_rl/models/policy/lm_policy.py (1)
  • init_collective (236-248)
nemo_rl/algorithms/grpo.py (9)
nemo_rl/distributed/virtual_cluster.py (1)
  • world_size (357-358)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
  • init_collective (34-55)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • init_collective (370-407)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • init_collective (479-496)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
  • init_collective (504-518)
nemo_rl/models/policy/interfaces.py (1)
  • init_collective (143-146)
nemo_rl/models/policy/lm_policy.py (1)
  • init_collective (236-248)
nemo_rl/models/policy/megatron_policy_worker.py (1)
  • init_collective (823-840)
nemo_rl/models/generation/interfaces.py (1)
  • init_collective (212-216)
🪛 Ruff (0.13.2)
nemo_rl/models/policy/dtensor_policy_worker.py

505-505: Unused method argument: train_world_size

(ARG002)

nemo_rl/models/policy/megatron_policy_worker.py

824-824: Unused method argument: train_world_size

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR

@youngeunkwon0405 youngeunkwon0405 changed the title fix: Enhancing non-colocated refit performance by including all training ranks to the communication group fix: enhancing non-colocated refit performance by having inclusive comm group Oct 3, 2025
@youngeunkwon0405 youngeunkwon0405 added the Performance Related to improving performance label Oct 3, 2025
@youngeunkwon0405 youngeunkwon0405 marked this pull request as ready for review October 3, 2025 03:57
@github-actions

github-actions bot commented Oct 3, 2025

ℹ️ File Consistency Check

Check based on commit: d4840d8 (PR #1264 from async-refit)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@github-actions
github-actions bot commented Oct 3, 2025

ℹ️ File Consistency Check (commit 33e64a5): ✅ DTensor Policy Worker Synchronization Check — both DTensor policy worker files were modified in this PR; please keep the changes consistent where applicable.

@youngeunkwon0405 youngeunkwon0405 added the CI:L1 Run doctests, unit tests, and functional tests label Oct 3, 2025
@youngeunkwon0405 youngeunkwon0405 requested a review from a team as a code owner October 3, 2025 17:30
@github-actions
github-actions bot commented Oct 3, 2025

ℹ️ File Consistency Check (commit 2081613): ✅ DTensor Policy Worker Synchronization Check — both DTensor policy worker files were modified in this PR; please keep the changes consistent where applicable.

@youngeunkwon0405 youngeunkwon0405 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 3, 2025
@youngeunkwon0405 youngeunkwon0405 removed the CI:L1 Run doctests, unit tests, and functional tests label Oct 7, 2025
@youngeunkwon0405 youngeunkwon0405 added the CI:L0 Run doctests and unit tests label Oct 7, 2025
@github-actions
github-actions bot commented Oct 7, 2025

ℹ️ File Consistency Check (commit 69a481f): ✅ DTensor Policy Worker Synchronization Check — both DTensor policy worker files were modified in this PR; please keep the changes consistent where applicable.

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 youngeunkwon0405 added CI:L0 Run doctests and unit tests and removed CI:L0 Run doctests and unit tests labels Oct 7, 2025
@github-actions
github-actions bot commented Oct 7, 2025

ℹ️ File Consistency Check (commit 84f92bf): ✅ DTensor Policy Worker Synchronization Check — both DTensor policy worker files were modified in this PR; please keep the changes consistent where applicable.

@youngeunkwon0405 youngeunkwon0405 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L0 Run doctests and unit tests labels Oct 7, 2025
@youngeunkwon0405 youngeunkwon0405 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 7, 2025
@youngeunkwon0405
Contributor Author

youngeunkwon0405 commented Oct 8, 2025

@guyueh1 now it is pipe cleaned, but your review has become stale. Added test case is the only change compared to the past. Can I get your re-approval? For the next step, maybe we can ask @terrykong for a merge.

@terrykong terrykong enabled auto-merge (squash) October 8, 2025 06:20
Collaborator

@terrykong terrykong left a comment


lgtm. awaiting @guyueh1 's review which will trigger merge

Contributor

@guyueh1 guyueh1 left a comment


Great feature, thanks


Labels

CI:L1 Run doctests, unit tests, and functional tests Performance Related to improving performance r0.4.0

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Refit speedup for non-colocated (NCCL)

4 participants