fix: enhance non-colocated refit performance by using an inclusive comm group #1264
Conversation
📝 Walkthrough

Updates propagate an explicit train world size through distributed initialization. World size is now computed as train_cluster.world_size() plus the number of inference workers. init_collective signatures gain a train_world_size parameter across the policy, the VLLM generation stack, and the workers. Rank calculation and communicator setup are adjusted so that all train ranks join, removing the prior rank-0 special-casing.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Trainer as Training Ranks (0..train_world_size-1)
    participant Orchestrator as Init Orchestrator
    participant Inference as Inference Workers (VLLM)
    Note over Orchestrator: Compute train_world_size and inference_world_size<br/>world_size = train + inference
    Orchestrator->>Trainer: init_collective(ip, port, world_size, train_world_size)
    loop for each training rank
        Trainer->>Trainer: Join NCCL (rank = self.rank)
    end
    Orchestrator->>Inference: init_collective(ip, port, world_size, train_world_size)
    loop for each inference worker
        Inference->>Inference: rank = train_world_size + rank_prefix + local_rank
        Inference->>Inference: Join NCCL
    end
    Note over Trainer,Inference: Single communicator across training + inference
    Trainer->>Inference: Broadcast weights (src=train rank 0)
    Inference-->>Trainer: Acks
```
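The rank layout described in the diagram can be sketched in plain Python. This is a hypothetical illustration of the inclusive-communicator rank math: the names `rank_prefix` and `local_rank` are taken from the diagram, and the function itself is not a NeMo-RL API.

```python
def communicator_rank(
    is_inference: bool,
    train_world_size: int,
    rank: int = 0,
    rank_prefix: int = 0,
    local_rank: int = 0,
) -> int:
    """Map a worker onto the single shared NCCL communicator.

    Training ranks keep their own rank in [0, train_world_size); inference
    workers are offset past the entire training group, so the two groups
    never collide.
    """
    if not is_inference:
        if not 0 <= rank < train_world_size:
            raise ValueError(
                f"train rank {rank} out of range [0, {train_world_size})"
            )
        return rank
    return train_world_size + rank_prefix + local_rank


# Example layout: 4 training ranks plus 2 inference workers (world_size = 6).
assert communicator_rank(False, 4, rank=3) == 3
assert communicator_rank(True, 4, rank_prefix=0, local_rank=0) == 4
assert communicator_rank(True, 4, rank_prefix=0, local_rank=1) == 5
```

With this layout, the weight broadcast from train rank 0 reaches every inference rank without any relay through a dedicated bridge rank, which is what removes the prior rank-0 special-casing.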
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

✅ Passed checks (4 passed)
Actionable comments posted: 2
🧹 Nitpick comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)
Lines 823-841: Leverage `train_world_size` to validate communicator ranks.

`train_world_size` is never used, so a mis-sized training group (or an unexpected torch rank) would silently collide with inference ranks when the NCCL communicator is created. Guard the rank against `train_world_size` before calling `StatelessProcessGroup.create` so we fail fast on bad deployments.

```diff
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.utils import StatelessProcessGroup

-# All training ranks [0..train_world_size-1] join the communicator with their rank
+if not 0 <= self.rank < train_world_size:
+    raise ValueError(
+        f"Megatron rank {self.rank} must be < train_world_size ({train_world_size})"
+    )
+
+# All training ranks [0..train_world_size-1] join the communicator with their rank
 pg = StatelessProcessGroup.create(
     host=ip, port=port, rank=self.rank, world_size=world_size
 )
```
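The fail-fast guard suggested above can be exercised in isolation. This is a minimal sketch in plain Python (no vLLM or Megatron dependencies); the function name is hypothetical, chosen to mirror the message in the diff:

```python
def check_train_rank(rank: int, train_world_size: int) -> None:
    """Fail fast if a training rank falls outside [0, train_world_size)."""
    if not 0 <= rank < train_world_size:
        raise ValueError(
            f"Megatron rank {rank} must be < train_world_size ({train_world_size})"
        )


check_train_rank(3, 4)  # valid: last training rank in a group of 4

try:
    # Rank 4 would collide with the first inference rank in the
    # inclusive communicator, so it must be rejected up front.
    check_train_rank(4, 4)
except ValueError as exc:
    print(exc)
```

Failing at init time turns a silent rank collision (which would surface later as a hung or corrupted NCCL broadcast) into an immediate, explainable error.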
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- nemo_rl/algorithms/distillation.py (1 hunks)
- nemo_rl/algorithms/grpo.py (1 hunks)
- nemo_rl/models/generation/vllm/vllm_backend.py (1 hunks)
- nemo_rl/models/generation/vllm/vllm_generation.py (2 hunks)
- nemo_rl/models/generation/vllm/vllm_worker.py (2 hunks)
- nemo_rl/models/generation/vllm/vllm_worker_async.py (2 hunks)
- nemo_rl/models/policy/dtensor_policy_worker.py (2 hunks)
- nemo_rl/models/policy/interfaces.py (1 hunks)
- nemo_rl/models/policy/lm_policy.py (1 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/algorithms/distillation.py, nemo_rl/models/generation/vllm/vllm_worker.py, nemo_rl/models/generation/vllm/vllm_worker_async.py, nemo_rl/models/policy/dtensor_policy_worker.py, nemo_rl/models/generation/vllm/vllm_generation.py, nemo_rl/models/policy/lm_policy.py, nemo_rl/models/policy/interfaces.py, nemo_rl/models/generation/vllm/vllm_backend.py, nemo_rl/models/policy/megatron_policy_worker.py, nemo_rl/algorithms/grpo.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/algorithms/distillation.py, nemo_rl/models/generation/vllm/vllm_worker.py, nemo_rl/models/generation/vllm/vllm_worker_async.py, nemo_rl/models/policy/dtensor_policy_worker.py, nemo_rl/models/generation/vllm/vllm_generation.py, nemo_rl/models/policy/lm_policy.py, nemo_rl/models/policy/interfaces.py, nemo_rl/models/generation/vllm/vllm_backend.py, nemo_rl/models/policy/megatron_policy_worker.py, nemo_rl/algorithms/grpo.py
🧬 Code graph analysis (7)
nemo_rl/algorithms/distillation.py (8)
- nemo_rl/distributed/virtual_cluster.py (1): world_size (357-358)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): init_collective (34-55)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): init_collective (370-407)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): init_collective (479-496)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): init_collective (504-518)
- nemo_rl/models/policy/interfaces.py (1): init_collective (143-146)
- nemo_rl/models/policy/lm_policy.py (1): init_collective (236-248)
- nemo_rl/models/policy/megatron_policy_worker.py (1): init_collective (823-840)

nemo_rl/models/policy/dtensor_policy_worker.py (7)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): init_collective (34-55)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): init_collective (370-407)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): init_collective (479-496)
- nemo_rl/models/policy/interfaces.py (1): init_collective (143-146)
- nemo_rl/models/policy/lm_policy.py (1): init_collective (236-248)
- nemo_rl/models/policy/megatron_policy_worker.py (1): init_collective (823-840)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (1): init_collective (462-472)

nemo_rl/models/policy/lm_policy.py (2)
- nemo_rl/distributed/virtual_cluster.py (1): world_size (357-358)
- nemo_rl/distributed/worker_groups.py (1): run_all_workers_single_data (728-772)

nemo_rl/models/policy/interfaces.py (1)
- nemo_rl/distributed/virtual_cluster.py (1): world_size (357-358)

nemo_rl/models/generation/vllm/vllm_backend.py (1)
- nemo_rl/distributed/virtual_cluster.py (1): world_size (357-358)

nemo_rl/models/policy/megatron_policy_worker.py (5)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): init_collective (34-55)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): init_collective (370-407)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): init_collective (504-518)
- nemo_rl/models/policy/interfaces.py (1): init_collective (143-146)
- nemo_rl/models/policy/lm_policy.py (1): init_collective (236-248)

nemo_rl/algorithms/grpo.py (9)
- nemo_rl/distributed/virtual_cluster.py (1): world_size (357-358)
- nemo_rl/models/generation/vllm/vllm_backend.py (1): init_collective (34-55)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): init_collective (370-407)
- nemo_rl/models/generation/vllm/vllm_worker.py (1): init_collective (479-496)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): init_collective (504-518)
- nemo_rl/models/policy/interfaces.py (1): init_collective (143-146)
- nemo_rl/models/policy/lm_policy.py (1): init_collective (236-248)
- nemo_rl/models/policy/megatron_policy_worker.py (1): init_collective (823-840)
- nemo_rl/models/generation/interfaces.py (1): init_collective (212-216)
🪛 Ruff (0.13.2)
nemo_rl/models/policy/dtensor_policy_worker.py
505-505: Unused method argument: train_world_size
(ARG002)
nemo_rl/models/policy/megatron_policy_worker.py
824-824: Unused method argument: train_world_size
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
ℹ️ File Consistency Check

Check based on commit: d4840d8 (PR #1264)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR. Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@guyueh1 it is now pipe-cleaned, but your review has become stale. The added test case is the only change compared to the previous version. Can I get your re-approval? As the next step, maybe we can ask @terrykong for a merge.
What does this PR do?
Speedup
QWEN3 32B run
2x speedup in QWEN3 32B model run
Nsys report
Current
The broadcast takes much longer than the all-gather.

This PR
The latency of broadcast and all-gather is similar.

Issues
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit