# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional

import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
                                                    TorchTensorAllocator)
from torch.distributed import ReduceOp

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)


class RayPPCommunicator(Communicator):
    """
    Communicator to be used for pipeline parallelism in Ray Compiled Graph.
    This wraps around the vLLM _PP GroupCoordinator.

    This class is not thread-safe.
    """

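    # Device communicator of the vLLM _PP group on worker actors; None on the
    # Ray driver, which does not participate in pipeline communication.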
    _comm: Optional[DeviceCommunicatorBase]

    def __init__(
        self,
        world_size: int,
        comm_id: Any,
        rank: Optional[int],
        actor_handles: list["ray.actor.ActorHandle"],
        cuda_stream: Optional[torch.cuda.Stream],
        use_communication_streams: bool = False,
    ):
        """
        Initialize a RayPPCommunicator that can be used to communicate with
        other Ray Compiled Graph actors for pipeline parallelism.

        Args:
            world_size: The number of participating actors.
            comm_id: A unique communicator ID. This is just to conform with
                the Ray Communicator API and is not used.
            rank: The rank of this actor. If None, then the caller is not a
                participant of the RayPPCommunicator group (e.g., the Ray
                driver).
            actor_handles: A list of actor handles.
            cuda_stream: A CUDA stream to dispatch communication ops to. This
                is not supported.
            use_communication_streams: Whether to use communication streams.
                This is not supported.
        """
        self._world_size = world_size
        self._rank: Optional[int] = None
        self._actor_handles = actor_handles
        if use_communication_streams:
            raise NotImplementedError(
                "use_communication_streams is not supported")
        if cuda_stream is not None and cuda_stream != current_stream():
            raise ValueError(
                "cuda_stream other than the current stream is not supported")

        if rank is not None:
            # Rank is not None, so this is a Ray worker actor.
            assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"

            self._comm = get_pp_group().device_communicator

            # Since we wrap around the vLLM _PP communicator, we use
            # the rank from the vLLM communicator and ignore the rank
            # passed in from Ray.
            # TODO(rui): refactor the Ray Communicator API so that
            # it also supports no rank passed in.
            self._rank = self._comm.rank_in_group

            self._build_actor_rank_mapping()
        else:
            # Rank is None, so this is the Ray driver process.
            self._comm = None

        self._closed = False

    def _build_actor_rank_mapping(self):
        """
        Use collective communication to build a mapping from actor IDs to
        ranks in the _PP group. This should be called once during
        initialization: Ray later identifies peers by actor handle (see
        `get_rank`), while the device communicator identifies them by rank.
        """
        if self._comm is None:
            return

        current_actor = ray.get_runtime_context().current_actor
        actor_id_str = current_actor._actor_id.hex()

        # Ray actor IDs are 32-character hex strings (128 bits).
        ACTOR_ID_LEN = 32
        actor_id_bytes = actor_id_str.encode('utf-8')
        assert len(
            actor_id_bytes
        ) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"

        # Wrap the bytes in a bytearray so torch.frombuffer gets a writable
        # buffer and does not warn about non-writable tensors.
        actor_id_tensor = torch.frombuffer(
            bytearray(actor_id_bytes),
            dtype=torch.uint8).to(self._comm.device)

        # All-gather the full actor IDs from all actors.
        gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)

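        # all_gather concatenates along dim 0 in rank order, so the gathered
        # tensor holds world_size * ACTOR_ID_LEN bytes, and the slice
        # [rank * ACTOR_ID_LEN, (rank + 1) * ACTOR_ID_LEN) is that rank's
        # actor ID.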
        # Build mapping: actor_id -> device_comm_rank
        self._actor_id_to_rank = {}
        for rank in range(self._world_size):
            start_idx = rank * ACTOR_ID_LEN
            end_idx = (rank + 1) * ACTOR_ID_LEN
            actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy(
            ).tobytes()
            actor_id = actor_bytes.decode('utf-8')
            self._actor_id_to_rank[actor_id] = rank

    def initialize(self, rank: int) -> None:
        # No additional initialization is needed.
        pass

    def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
        return self._actor_handles

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank, looked up in the actor-ID-to-rank
        mapping built with device communicator collectives at init time.
        """
        assert hasattr(self, '_actor_id_to_rank'), (
            "Actor rank mapping not built. "
            "This should have been done during initialization.")

        actor_id_str = actor._actor_id.hex()

        if actor_id_str in self._actor_id_to_rank:
            return self._actor_id_to_rank[actor_id_str]  # type: ignore
        else:
            raise ValueError(f"Actor {actor} not found in communicator group")

    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank.
        """
        return self._rank

    def get_world_size(self) -> int:
        """
        Return the number of ranks in the RayPPCommunicator group.
        """
        return self._world_size

    def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `buf` until the send has finished.
        That is, either all writes should be submitted on the current stream
        (the stream returned by `current_stream()`) or, if on a different
        stream, that stream should synchronize with the current stream.

        Args:
            buf: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
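        # peer_rank is a rank within the vLLM _PP group: get_rank() resolves
        # actor handles to these same ranks, so it can be passed to the
        # device communicator directly.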
        self._comm.send(buf, peer_rank)

    def recv(
        self,
        shape: tuple[int, ...],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: TorchTensorAllocator,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize the current stream.

        After this call returns, the receive buffer is safe to read from any
        stream. A RayChannelError is raised if an error occurred (e.g., the
        remote actor died), in which case the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive.
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: The allocator to use to create the received tensor.
                This is ignored for this implementation.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        size = torch.Size(shape)
        buf = self._comm.recv(size, dtype, src=peer_rank)

        # Buffer values are undefined if NCCL ops are aborted. Therefore, we
        # need to synchronize here and check that the channel is still
        # open to ensure that the receive buffer is valid.
        # TODO(swang): Avoid CUDA synchronization.
        current_stream().synchronize()

        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")
        return buf

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ):
        raise NotImplementedError("allgather is not supported")

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("allreduce is not supported")

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("reducescatter is not supported")

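    # Ray may enter these stream contexts around its send/recv operations.
    # Separate communication streams are not supported, so both map to the
    # current compute stream.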
    @property
    def recv_stream(self):
        return torch.cuda.StreamContext(current_stream())

    @property
    def send_stream(self):
        return torch.cuda.StreamContext(current_stream())

    def destroy(self) -> None:
        # Just sets a flag; vLLM manages the lifecycle of the underlying
        # _PP GroupCoordinator.
        self._closed = True

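    # The transport is reported as "nccl" even though the actual send/recv
    # ops are delegated to the vLLM _PP device communicator (NCCL-backed on
    # CUDA devices).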
    def get_transport_name(self) -> str:
        return "nccl"

    @classmethod
    def generate_communicator_id(cls) -> Any:
        return uuid.uuid4()