13 changes: 13 additions & 0 deletions skyrl-train/skyrl_train/distributed/utils.py
@@ -41,6 +41,19 @@ def get_free_port():
     return sock.getsockname()[1]
 
 
+# Reference: https://github.com/vllm-project/vllm/blob/196cdc3224112df7f68c901fe4c5314875a65be8/examples/offline_inference/rlhf.py
+def stateless_init_process_group(master_address, master_port, rank, world_size, device):
+    """Uses vLLM's `StatelessProcessGroup` to create a process group
+    without considering the global process group in torch.distributed.
+    """
+    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    from vllm.distributed.utils import StatelessProcessGroup
+
+    pg = StatelessProcessGroup.create(host=master_address, port=master_port, rank=rank, world_size=world_size)
+    pynccl = PyNcclCommunicator(pg, device=device)
+    return pynccl
+
+
 # Copy from pytorch to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
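For orientation, a minimal pairing sketch of the new helper (the address, port, and world size below are assumed for illustration, not taken from the diff): every rank of the weight-sync group (the trainer's rank 0 and each inference-engine rank) calls stateless_init_process_group with the same rendezvous endpoint and world_size, differing only in rank, and StatelessProcessGroup.create blocks until all ranks have joined.

    import torch
    from skyrl_train.distributed.utils import stateless_init_process_group

    # Assumed rendezvous parameters, for illustration only.
    MASTER_ADDR, MASTER_PORT = "10.0.0.1", 29600
    WORLD_SIZE = 3  # trainer rank 0 plus two inference-engine ranks

    # Trainer side (rank 0 of the weight-sync group); each engine rank r
    # issues the identical call with rank=r at the same time.
    comm = stateless_init_process_group(
        MASTER_ADDR, MASTER_PORT, rank=0, world_size=WORLD_SIZE,
        device=torch.device(f"cuda:{torch.cuda.current_device()}"),
    )
    # Weights then flow through the returned PyNcclCommunicator, e.g.
    # comm.broadcast(tensor, src=0, stream=torch.cuda.current_stream()).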
28 changes: 19 additions & 9 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -21,7 +21,7 @@
 )
 from vllm.lora.request import LoRARequest
 from torch.distributed import destroy_process_group
-from skyrl_train.distributed.utils import init_custom_process_group
+from skyrl_train.distributed.utils import init_custom_process_group, stateless_init_process_group
 from uuid import uuid4
 import warnings
 from skyrl_train.inference_engines.base import (
@@ -101,13 +101,18 @@ def init_weight_update_communicator(
             f"torch.distributed.get_rank(): {torch.distributed.get_rank()}, rank_offset: {rank_offset}, rank: {rank}, world_size: {world_size}, group_name: {group_name}"
         )
 
-        self._model_update_group = init_custom_process_group(
-            backend=backend,
-            init_method=get_tcp_url(master_address, master_port),
-            world_size=world_size,
-            rank=rank,
-            group_name=group_name,
-        )
+        if backend == "nccl":
+            self._model_update_group = stateless_init_process_group(
+                master_address, master_port, rank, world_size, torch.device(f"cuda:{torch.cuda.current_device()}")
+            )
+        else:
+            self._model_update_group = init_custom_process_group(
+                backend=backend,
+                init_method=get_tcp_url(master_address, master_port),
+                world_size=world_size,
+                rank=rank,
+                group_name=group_name,
+            )
         logger.info(
             f"init_weight_update_communicator: master_address={master_address}, master_port={master_port}, ",
             f"rank={rank}, world_size={world_size}, group_name={group_name}",
@@ -667,11 +672,16 @@ def receive_weights(self, request: NamedWeightsUpdateRequest) -> Iterator[Tuple[
 
     def _receive_broadcast(self, request: NamedWeightsUpdateRequest) -> Iterator[Tuple[str, torch.Tensor]]:
         """Receive weights via torch.distributed.broadcast."""
+        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+
         for name, dtype_str, shape in zip(request["names"], request["dtypes"], request["shapes"]):
             dtype = str_to_torch_dtype(dtype_str)
             assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
             weight = torch.empty(shape, dtype=dtype, device="cuda")
-            torch.distributed.broadcast(weight, 0, group=self.model_update_group)
+            if isinstance(self.model_update_group, PyNcclCommunicator):
+                self.model_update_group.broadcast(weight, src=0, stream=torch.cuda.current_stream())
+            else:
+                torch.distributed.broadcast(weight, 0, group=self.model_update_group)
             yield name, weight
 
     def _receive_ipc(self, request: NamedWeightsUpdateRequest) -> Iterator[Tuple[str, torch.Tensor]]:
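On the receive path, the engine now dispatches on the concrete group type: a PyNcclCommunicator (NCCL backend) exposes its own broadcast(tensor, src, stream) method, while every other backend still goes through torch.distributed.broadcast. A condensed sketch of the pattern with assumed placeholder names:

    import torch
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

    def recv_weight(group, shape, dtype):
        # Allocate the destination buffer on the engine's GPU, then receive
        # the tensor broadcast by the trainer (src rank 0) over whichever
        # group type is in use.
        buf = torch.empty(shape, dtype=dtype, device="cuda")
        if isinstance(group, PyNcclCommunicator):
            group.broadcast(buf, src=0, stream=torch.cuda.current_stream())
        else:
            torch.distributed.broadcast(buf, 0, group=group)
        return buf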
@@ -222,7 +222,10 @@ async def broadcast_to_inference_engines(self, inference_engine_client):
         # Broadcast tensor
         def broadcast_tensor(tensor):
             if torch.distributed.get_rank() == 0:
-                torch.distributed.broadcast(tensor.data, 0, group=self._model_update_group)
+                if self.cfg.generator.weight_sync_backend == "nccl" and self.cfg.generator.backend == "vllm":
+                    self._model_update_group.broadcast(tensor.data, src=0, stream=torch.cuda.current_stream())
+                else:
+                    torch.distributed.broadcast(tensor.data, 0, group=self._model_update_group)
 
         await asyncio.to_thread(broadcast_tensor, tensor)
         if torch.distributed.get_rank() == 0:
5 changes: 4 additions & 1 deletion skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -248,7 +248,10 @@ async def broadcast_to_inference_engines(self, inference_engine_client):
         # Broadcast tensor
         def broadcast_tensor(tensor):
             if torch.distributed.get_rank() == 0:
-                torch.distributed.broadcast(tensor.data, 0, group=self._model_update_group)
+                if self.cfg.generator.weight_sync_backend == "nccl" and self.cfg.generator.backend == "vllm":
+                    self._model_update_group.broadcast(tensor.data, src=0, stream=torch.cuda.current_stream())
+                else:
+                    torch.distributed.broadcast(tensor.data, 0, group=self._model_update_group)
 
         await asyncio.to_thread(broadcast_tensor, tensor)
         if torch.distributed.get_rank() == 0:
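Both worker variants gate the sender side the same way: when weight_sync_backend is "nccl" and the generator backend is vLLM, self._model_update_group is a PyNcclCommunicator, so rank 0 calls its broadcast method directly; any other combination keeps the torch.distributed path. A minimal sketch of how the two sides pair up (tensor shapes and dtypes are assumed to agree by construction; names here are illustrative):

    # Sender (trainer, rank 0 of the weight-sync group), NCCL/vLLM path:
    group.broadcast(param.data, src=0, stream=torch.cuda.current_stream())

    # Receiver (inference engine, rank r > 0), running concurrently:
    buf = torch.empty(shape, dtype=dtype, device="cuda")
    group.broadcast(buf, src=0, stream=torch.cuda.current_stream())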
26 changes: 23 additions & 3 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -31,7 +31,7 @@
 from transformers import PreTrainedModel
 from loguru import logger
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
-from skyrl_train.distributed.utils import init_custom_process_group
+from skyrl_train.distributed.utils import init_custom_process_group, stateless_init_process_group
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
 from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
 from skyrl_train.dataset.replay_buffer import Experience
@@ -297,11 +297,31 @@ async def init_weight_sync_state(self, inference_engine_client: InferenceEngineC
             )
         )
 
+        def _init_process_group(backend, master_addr, master_port, world_size, rank, group_name):
+            if backend == "nccl" and self.cfg.generator.backend == "vllm":
+                model_update_group = stateless_init_process_group(
+                    master_address=master_addr,
+                    master_port=master_port,
+                    world_size=world_size,
+                    rank=rank,
+                    device=torch.device(f"cuda:{torch.cuda.current_device()}"),
+                )
+            else:
+                model_update_group = init_custom_process_group(
+                    backend=backend,
+                    init_method=get_tcp_url(master_addr, master_port),
+                    rank=rank,
+                    world_size=world_size,
+                    group_name=group_name,
+                )
+            return model_update_group
+
         tasks.append(
             asyncio.to_thread(
-                init_custom_process_group,
+                _init_process_group,
                 backend=backend,
-                init_method=get_tcp_url(master_addr, master_port),
+                master_addr=master_addr,
+                master_port=master_port,
                 world_size=world_size,
                 rank=0,
                 group_name=group_name,
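A hedged reading of the control flow here (only _init_process_group and stateless_init_process_group are from the diff; the rest is assumed): the closure selects the stateless NCCL path solely for the nccl weight-sync backend with a vLLM generator, and it runs under asyncio.to_thread because both StatelessProcessGroup.create and init_custom_process_group block until every rank has joined the rendezvous. The trainer-side init and each engine's init_weight_update_communicator call must therefore be in flight at the same time, e.g.:

    # Hypothetical sketch (not shown in the diff): all pending init tasks,
    # trainer-side and engine-side, are awaited together; awaiting them one
    # by one would deadlock the rendezvous.
    results = await asyncio.gather(*tasks)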