Commit 331faf9

ruisearch42 authored and epwalsh committed
Introduce RayPPCommunicator for ray-based PP (vllm-project#21660)
Signed-off-by: Rui Qiao <[email protected]>
1 parent aa913ea commit 331faf9

File tree

3 files changed: +280 -0 lines changed

vllm/distributed/device_communicators/ray_communicator.py
vllm/envs.py
vllm/executor/ray_distributed_executor.py

vllm/distributed/device_communicators/ray_communicator.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional

import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
                                                    TorchTensorAllocator)
from torch.distributed import ReduceOp

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)


class RayPPCommunicator(Communicator):
    """
    Communicator to be used for pipeline parallelism in Ray Compiled Graph.
    This wraps around the vLLM _PP GroupCoordinator.

    This class is not thread-safe.
    """

    _comm: Optional[DeviceCommunicatorBase]

    def __init__(
        self,
        world_size: int,
        comm_id: Any,
        rank: Optional[int],
        actor_handles: list["ray.actor.ActorHandle"],
        cuda_stream: Optional[torch.cuda.Stream],
        use_communication_streams: bool = False,
    ):
        """
        Initialize a RayPPCommunicator that can be used to communicate with
        other Ray Compiled Graph actors for pipeline parallelism.

        Args:
            world_size: The number of participating actors.
            comm_id: A unique communicator ID. This is just to conform with
                the Ray Communicator API and is not used.
            rank: The rank of this actor. If None, then the caller is not a
                participant of the RayPPCommunicator group (e.g., the Ray
                driver).
            actor_handles: A list of actor handles.
            cuda_stream: A CUDA stream to dispatch communication ops to. This
                is not supported.
            use_communication_streams: Whether to use communication streams.
                This is not supported.
        """
        self._world_size = world_size
        self._rank: Optional[int] = None
        self._actor_handles = actor_handles
        if use_communication_streams:
            raise NotImplementedError(
                "use_communication_streams is not supported")
        if cuda_stream is not None and cuda_stream != current_stream():
            raise ValueError(
                "cuda_stream other than the current stream is not supported")

        if rank is not None:
            # Rank is not None, so this is a Ray worker.
            assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"

            self._comm = get_pp_group().device_communicator

            # Since we wrap around the vLLM _PP communicator, we use
            # the rank from the vLLM communicator, and ignore the rank
            # passed in from Ray.
            # TODO(rui): refactor the Ray Communicator API so that
            # it also supports no rank passed in.
            self._rank = self._comm.rank_in_group

            self._build_actor_rank_mapping()
        else:
            # Rank is None, so this is the Ray driver.
            self._comm = None

        self._closed = False

    def _build_actor_rank_mapping(self):
        """
        Use collective communication to build a mapping from actor IDs to
        ranks. This should be called once during initialization.
        """
        if self._comm is None:
            return {}

        current_actor = ray.get_runtime_context().current_actor
        actor_id_str = current_actor._actor_id.hex()

        # Ray actor IDs are 32-character hex strings (128 bits).
        ACTOR_ID_LEN = 32
        actor_id_bytes = actor_id_str.encode('utf-8')
        assert len(
            actor_id_bytes
        ) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"

        actor_id_tensor = torch.frombuffer(
            actor_id_bytes, dtype=torch.uint8).to(self._comm.device)

        # All-gather full actor IDs from all actors.
        gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)

        # Build mapping: actor_id -> device_comm_rank
        self._actor_id_to_rank = {}
        for rank in range(self._world_size):
            start_idx = rank * ACTOR_ID_LEN
            end_idx = (rank + 1) * ACTOR_ID_LEN
            actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy(
            ).tobytes()
            actor_id = actor_bytes.decode('utf-8')
            self._actor_id_to_rank[actor_id] = rank

    def initialize(self, rank: int) -> None:
        # No additional initialization is needed.
        pass

    def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
        return self._actor_handles

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank using device communicator collective ops.
        """
        assert hasattr(self, '_actor_id_to_rank'), (
            "Actor rank mapping not built. "
            "This should have been done during initialization.")

        actor_id_str = actor._actor_id.hex()

        if actor_id_str in self._actor_id_to_rank:
            return self._actor_id_to_rank[actor_id_str]  # type: ignore
        else:
            raise ValueError(f"Actor {actor} not found in communicator group")

    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank.
        """
        return self._rank

    def get_world_size(self) -> int:
        """
        Return the number of ranks in the RayPPCommunicator group.
        """
        return self._world_size

    def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `buf` until the send has finished.
        That is, either all writes should be submitted on the current stream
        (self._cuda_stream) or, if on a different stream, that stream should
        synchronize with the current stream.

        Args:
            buf: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        self._comm.send(buf, peer_rank)

    def recv(
        self,
        shape: tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: TorchTensorAllocator,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize the current stream.

        After this call returns, the receive buffer is safe to read from any
        stream. A RayChannelError will be raised if an error occurred (e.g.,
        the remote actor died), and the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive.
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: The allocator to use to create the received tensor.
                This is ignored for this implementation.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        size = torch.Size(shape)
        buf = self._comm.recv(size, dtype, src=peer_rank)

        # Buffer values are undefined if NCCL ops are aborted. Therefore, we
        # need to synchronize here and check that the channel is still
        # open to ensure that the receive buffer is valid.
        # TODO(swang): Avoid CUDA synchronization.
        current_stream().synchronize()

        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")
        return buf

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ):
        raise NotImplementedError("allgather is not supported")

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("allreduce is not supported")

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("reducescatter is not supported")

    @property
    def recv_stream(self):
        return torch.cuda.StreamContext(current_stream())

    @property
    def send_stream(self):
        return torch.cuda.StreamContext(current_stream())

    def destroy(self) -> None:
        # Just sets a flag; vLLM manages the lifecycle of the underlying
        # _PP GroupCoordinator.
        self._closed = True

    def get_transport_name(self) -> str:
        return "nccl"

    @classmethod
    def generate_communicator_id(cls) -> Any:
        return uuid.uuid4()
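
The only non-trivial machinery above is _build_actor_rank_mapping: each worker encodes its fixed-length actor ID as a uint8 tensor, all-gathers the tensors across the pipeline-parallel group, and decodes the result into an actor-ID-to-rank dictionary that get_rank later consults. Below is a minimal standalone sketch of the same idea, assuming an already-initialized torch.distributed process group and using its list-based all_gather instead of the device communicator's concatenating all_gather; the helper name build_id_to_rank is illustrative and not part of this commit.

# Illustrative sketch only, not part of this commit.
import torch
import torch.distributed as dist

def build_id_to_rank(my_id: str, id_len: int, world_size: int) -> dict:
    """Map fixed-length string IDs to ranks via an all-gather of uint8 tensors."""
    id_bytes = my_id.encode("utf-8")
    assert len(id_bytes) == id_len, f"Unexpected ID length: {len(id_bytes)}"
    # bytearray keeps the buffer writable, avoiding torch.frombuffer's
    # warning about read-only buffers.
    local = torch.frombuffer(bytearray(id_bytes), dtype=torch.uint8)
    gathered = [torch.empty(id_len, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # every rank receives every ID
    # Decode each gathered chunk back into a string and record its rank.
    return {bytes(chunk.tolist()).decode("utf-8"): rank
            for rank, chunk in enumerate(gathered)}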

vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -55,6 +55,7 @@
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
     VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
+    VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
     VLLM_XLA_USE_SPMD: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -498,6 +499,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
                  ),
 
+    # If the env var is set, it uses a Ray Communicator wrapping
+    # vLLM's pipeline parallelism communicator to interact with Ray's
+    # Compiled Graph. Otherwise, it uses Ray's NCCL communicator.
+    # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
+    "VLLM_USE_RAY_WRAPPED_PP_COMM":
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))),
+
     # Use dedicated multiprocess context for workers.
     # Both spawn and fork work
     "VLLM_WORKER_MULTIPROC_METHOD":

vllm/executor/ray_distributed_executor.py

Lines changed: 15 additions & 0 deletions
@@ -608,6 +608,21 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
 
         forward_dag = MultiOutputNode(outputs)
 
+        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
+            from ray.experimental.channel.accelerator_context import (
+                register_accelerator_context)
+
+            from vllm.distributed.device_communicators.ray_communicator import (
+                RayPPCommunicator)
+            register_accelerator_context(torch_module_name="cuda",
+                                         communicator_cls=RayPPCommunicator)
+            logger.info("Using RayPPCommunicator "
+                        "(which wraps vLLM _PP GroupCoordinator) "
+                        "for Ray Compiled Graph communication.")
+        else:
+            logger.info("Using Ray's NCCL communicator for "
+                        "Ray Compiled Graph communication.")
+
         return forward_dag.experimental_compile(
             enable_asyncio=enable_asyncio,
             _overlap_gpu_communication=envs.
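
Taken together with the envs.py change above, this means the wrapped communicator is picked automatically whenever Ray Compiled Graph is enabled, since VLLM_USE_RAY_WRAPPED_PP_COMM defaults to 1; setting it to 0 falls back to Ray's own NCCL communicator. A hedged sketch of how this path might be exercised for a pipeline-parallel run on the Ray backend follows; the model name, parallel sizes, and LLM() arguments are illustrative assumptions, not taken from this diff.

# Illustrative sketch only; model and parallel sizes are placeholders.
import os

# Ray Compiled Graph must be enabled for the new flag to have any effect.
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# Already the default ("1"); set to "0" to fall back to Ray's NCCL communicator.
os.environ["VLLM_USE_RAY_WRAPPED_PP_COMM"] = "1"

from vllm import LLM

llm = LLM(model="facebook/opt-125m",
          pipeline_parallel_size=2,
          tensor_parallel_size=1,
          distributed_executor_backend="ray")
print(llm.generate("Hello, my name is")[0].outputs[0].text)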
