Skip to content

Commit 33e64a5

Browse files
draft
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
1 parent 557b7ec commit 33e64a5

File tree

11 files changed: +85 additions, -47 deletions

nemo_rl/algorithms/distillation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,11 +424,16 @@ def setup(
424424
if not colocated_inference:
425425
ip, port = train_cluster.get_master_address_and_port()
426426
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
427+
train_world_size = train_cluster.world_size()
427428
# inference cluster + head node of the train cluster
428-
world_size = inference_nodes * inference_gpus_per_node + 1
429+
world_size = train_world_size + inference_nodes * inference_gpus_per_node
429430
# init collective
430-
futures_train = student_policy.init_collective(ip, port, world_size)
431-
futures_inference = student_generation.init_collective(ip, port, world_size) # type: ignore
431+
futures_train = student_policy.init_collective(
432+
ip, port, world_size, train_world_size=train_world_size
433+
)
434+
futures_inference = student_generation.init_collective(
435+
ip, port, world_size, train_world_size=train_world_size
436+
) # type: ignore
432437
# wait for all futures to complete
433438
ray.get(futures_train + futures_inference)
434439

nemo_rl/algorithms/grpo.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,17 @@ def setup(
432432
if not colocated_inference:
433433
ip, port = train_cluster.get_master_address_and_port()
434434
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
435-
# inference cluster + head node of the train cluster
436-
world_size = inference_nodes * inference_gpus_per_node + 1
435+
# world includes all training workers and all inference workers
436+
train_world_size = train_cluster.world_size()
437+
inference_world_size = inference_nodes * inference_gpus_per_node
438+
world_size = train_world_size + inference_world_size
437439
# init collective
438-
futures_train = policy.init_collective(ip, port, world_size)
439-
futures_inference = policy_generation.init_collective(ip, port, world_size) # type: ignore
440+
futures_train = policy.init_collective(
441+
ip, port, world_size, train_world_size=train_world_size
442+
)
443+
futures_inference = policy_generation.init_collective(
444+
ip, port, world_size, train_world_size=train_world_size
445+
) # type: ignore
440446
# wait for all futures to complete
441447
ray.get(futures_train + futures_inference)
442448

nemo_rl/models/generation/vllm/vllm_backend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,20 @@
3232

3333
class VllmInternalWorkerExtension:
3434
def init_collective(
35-
self, rank_prefix: int, ip: str, port: int, world_size: int
35+
self,
36+
rank_prefix: int,
37+
ip: str,
38+
port: int,
39+
world_size: int,
40+
train_world_size: int,
3641
) -> None:
3742
"""Initialize the collective communication."""
3843
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
3944
from vllm.distributed.utils import StatelessProcessGroup
4045

4146
local_rank = torch.distributed.get_rank()
42-
rank = rank_prefix + local_rank + 1 # 1 is the head node of the train cluster
47+
# Place vLLM ranks after all training ranks so all training workers can join
48+
rank = train_world_size + rank_prefix + local_rank
4349

4450
pg = StatelessProcessGroup.create(
4551
host=ip, port=port, rank=rank, world_size=world_size

nemo_rl/models/generation/vllm/vllm_generation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _post_init(self):
368368
return results
369369

370370
def init_collective(
371-
self, ip: str, port: int, world_size: int
371+
self, ip: str, port: int, world_size: int, *, train_world_size: int
372372
) -> list[ray.ObjectRef]:
373373
"""Initialize the collective communication."""
374374
if not self.worker_group or not self.worker_group.workers:
@@ -395,7 +395,12 @@ def init_collective(
395395
method_name,
396396
rank_prefix=rank_prefix_list,
397397
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
398-
common_kwargs={"ip": ip, "port": port, "world_size": world_size},
398+
common_kwargs={
399+
"ip": ip,
400+
"port": port,
401+
"world_size": world_size,
402+
"train_world_size": train_world_size,
403+
},
399404
)
400405

401406
# this function should co-work with lm_policy, so we should wait for all futures to complete outside

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,12 @@ def post_init(self):
477477
self.vllm_device_ids = self.report_device_id()
478478

479479
def init_collective(
480-
self, rank_prefix: int, ip: str, port: int, world_size: int
480+
self,
481+
rank_prefix: int,
482+
ip: str,
483+
port: int,
484+
world_size: int,
485+
train_world_size: int,
481486
) -> None:
482487
self.llm.collective_rpc(
483488
"init_collective",
@@ -486,6 +491,7 @@ def init_collective(
486491
ip,
487492
port,
488493
world_size,
494+
train_world_size,
489495
),
490496
)
491497

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,12 @@ def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
393393
return thread, base_url, server
394394

395395
async def init_collective_async(
396-
self, rank_prefix: int, ip: str, port: int, world_size: int
396+
self,
397+
rank_prefix: int,
398+
ip: str,
399+
port: int,
400+
world_size: int,
401+
train_world_size: int,
397402
) -> None:
398403
await self.llm.collective_rpc(
399404
"init_collective",
@@ -402,6 +407,7 @@ async def init_collective_async(
402407
ip,
403408
port,
404409
world_size,
410+
train_world_size,
405411
),
406412
)
407413

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -501,17 +501,18 @@ def train_context(cp_context: Optional[Generator[None, None, None]] = None):
501501

502502
yield
503503

504-
def init_collective(self, ip: str, port: int, world_size: int) -> None:
504+
def init_collective(
505+
self, ip: str, port: int, world_size: int, *, train_world_size: int
506+
) -> None:
505507
"""Initialize the collective communication."""
506508
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
507509
from vllm.distributed.utils import StatelessProcessGroup
508510

509-
if self.rank == 0:
510-
pg = StatelessProcessGroup.create(
511-
host=ip, port=port, rank=0, world_size=world_size
512-
)
513-
device = torch.cuda.current_device()
514-
self.model_update_group = PyNcclCommunicator(pg, device=device)
511+
pg = StatelessProcessGroup.create(
512+
host=ip, port=port, rank=self.rank, world_size=world_size
513+
)
514+
device = torch.cuda.current_device()
515+
self.model_update_group = PyNcclCommunicator(pg, device=device)
515516

516517
def is_alive(self) -> bool:
517518
return True
@@ -1808,9 +1809,8 @@ def broadcast_weights_for_collective(self) -> None:
18081809
for _, tensor in self.model.state_dict().items():
18091810
if isinstance(tensor, DTensor):
18101811
tensor = tensor.full_tensor()
1811-
if self.rank == 0:
1812-
tensor = tensor.to(self.dtype, non_blocking=True)
1813-
self.model_update_group.broadcast(tensor.data, src=0)
1812+
tensor = tensor.to(self.dtype, non_blocking=True)
1813+
self.model_update_group.broadcast(tensor.data, src=0)
18141814

18151815
# Manually move model to cpu for cpu offload case
18161816
# cpu offload needs model on CPU before model forward

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -459,17 +459,17 @@ def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
459459
logits.div_(self.cfg["generation"]["temperature"])
460460
return logits
461461

462-
def init_collective(self, ip: str, port: int, world_size: int) -> None:
463-
"""Initialize the collective communication."""
462+
def init_collective(
463+
self, ip: str, port: int, world_size: int, *, train_world_size: int
464+
) -> None:
464465
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
465466
from vllm.distributed.utils import StatelessProcessGroup
466467

467-
if self.rank == 0:
468-
pg = StatelessProcessGroup.create(
469-
host=ip, port=port, rank=0, world_size=world_size
470-
)
471-
device = torch.cuda.current_device()
472-
self.model_update_group = PyNcclCommunicator(pg, device=device)
468+
pg = StatelessProcessGroup.create(
469+
host=ip, port=port, rank=self.rank, world_size=world_size
470+
)
471+
device = torch.cuda.current_device()
472+
self.model_update_group = PyNcclCommunicator(pg, device=device)
473473

474474
def is_alive(self) -> bool:
475475
return True
@@ -1770,9 +1770,8 @@ def broadcast_weights_for_collective(self) -> None:
17701770
for _, tensor in self.model.state_dict().items():
17711771
if isinstance(tensor, DTensor):
17721772
tensor = tensor.full_tensor()
1773-
if self.rank == 0:
1774-
tensor = tensor.to(self.dtype, non_blocking=True)
1775-
self.model_update_group.broadcast(tensor.data, src=0)
1773+
tensor = tensor.to(self.dtype, non_blocking=True)
1774+
self.model_update_group.broadcast(tensor.data, src=0)
17761775

17771776
# Manually move model to cpu for cpu offload case
17781777
# cpu offload needs model on CPU before model forward

nemo_rl/models/policy/interfaces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def shutdown(self) -> bool:
141141
class ColocatablePolicyInterface(PolicyInterface):
142142
@abstractmethod
143143
def init_collective(
144-
self, ip: str, port: int, world_size: int
144+
self, ip: str, port: int, world_size: int, *, train_world_size: int
145145
) -> list[ray.ObjectRef]:
146146
pass
147147

nemo_rl/models/policy/lm_policy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,15 @@ def __init__(
234234
self.cfg = config
235235

236236
def init_collective(
237-
self, ip: str, port: int, world_size: int
237+
self, ip: str, port: int, world_size: int, *, train_world_size: int
238238
) -> list[ray.ObjectRef]:
239239
"""Initialize the collective communication."""
240240
futures = self.worker_group.run_all_workers_single_data(
241-
"init_collective", ip=ip, port=port, world_size=world_size
241+
"init_collective",
242+
ip=ip,
243+
port=port,
244+
world_size=world_size,
245+
train_world_size=train_world_size,
242246
)
243247
# this function should co-work with vllm, so we should wait for all futures to complete outside
244248
return futures

Comments (0)