
Commit 4d34802

Author: R2-Y (committed)
NixlConnector + PP + PD compatible with Ray
1 parent 09eaed3 commit 4d34802

File tree: 8 files changed, +211 −51 lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 58 additions & 33 deletions
@@ -23,8 +23,9 @@
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
-    get_pipeline_model_parallel_rank, get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size, get_tp_group)
+    get_rank, get_pipeline_model_parallel_rank,
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
+    get_tp_group)
 from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
@@ -113,8 +114,7 @@ def add_new_req(
             remote_port=kv_transfer_params["remote_port"],
             # P workers don't need to receive tp_size from proxy here.
             tp_size=kv_transfer_params.get("tp_size", 1),
-            pp_size=kv_transfer_params.get("pp_size", 1)
-        )
+            pp_size=kv_transfer_params.get("pp_size", 1))
         if save_to_host:
             self.reqs_to_save[request_id] = _req
         if load_remote_cache:
@@ -223,9 +223,9 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
 
     def wait_for_save(self):
         assert self.connector_worker is not None
-        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
         if self.connector_worker.use_host_buffer and \
             self.connector_worker.copy_blocks:
+            assert isinstance(self._connector_metadata, NixlConnectorMetadata)
             self.connector_worker.save_kv_to_host(self._connector_metadata)
 
 
@@ -439,24 +439,26 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {pp_rank0: {rank0: agent_name0, rank1: agent_name1},
         #                      pp_rank1: {rank0: agent_name2, rank1: agent_name3, ...}, ...}.
-        self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = defaultdict(dict)
+        self._remote_agents: dict[EngineId,
+                                  dict[int, dict[int,
+                                                 str]]] = defaultdict(dict)
 
         # NIXL handshake port.
         # NOTE(rob): Within a DP group, each DP rank gets its own
         # base port (which is sent in the KVTransferParams).
-        # Each TP rank listens/queries on the base_port +
-        # pp_rank * tp_size + tp_rank.
-        self.pp_rank = get_pipeline_model_parallel_rank()
+        # Each TP rank listens/queries on the base_port + global_rank.
+        self.rank_in_dp_group = get_rank()
         self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
             vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size *
             vllm_config.parallel_config.pipeline_parallel_size +
-            self.pp_rank * vllm_config.parallel_config.tensor_parallel_size)
+            self.rank_in_dp_group)
 
         # Metadata.
         self.engine_id: EngineId = engine_id
         self.tp_rank = get_tensor_model_parallel_rank()
+        self.pp_rank = get_pipeline_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
         self.tp_group = get_tp_group()
         self.num_blocks = 0
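For clarity, a minimal standalone sketch of the resulting port layout follows. It is illustrative only, not part of the commit: the base port, the DP=2/PP=2/TP=2 sizes, and the assumption that get_rank() enumerates a DP group's workers PP-major, TP-minor are all stand-ins.

# Illustrative sketch of the per-worker side-channel port layout (assumed values).
BASE_PORT = 5600          # stand-in for envs.VLLM_NIXL_SIDE_CHANNEL_PORT
DP, PP, TP = 2, 2, 2

for dp_rank in range(DP):
    for pp_rank in range(PP):
        for tp_rank in range(TP):
            # Assumption: the in-group rank enumerates PP-major, TP-minor.
            rank_in_dp_group = pp_rank * TP + tp_rank
            port = BASE_PORT + dp_rank * TP * PP + rank_in_dp_group
            print(dp_rank, pp_rank, tp_rank, port)

# DP rank 0 occupies ports 5600-5603 and DP rank 1 ports 5604-5607, so every
# TP/PP worker gets a unique listener port and the listener no longer needs
# to add tp_rank itself.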
@@ -524,7 +526,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             max_workers=1,
             thread_name_prefix="vllm-nixl-handshake-initiator")
         self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
-        self._handshake_futures: dict[EngineId, dict[int, Future[dict[int, str]]]] = {}
+        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()
 
@@ -554,11 +556,30 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
-        self.device_id = torch.cuda.current_device()
+        self.device_id = self._get_current_device_id()
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
 
+    def _get_current_device_id(self) -> int:
+        """Get the current device ID in a platform-agnostic way."""
+        from vllm.platforms import current_platform
+
+        if current_platform.is_cuda_alike():
+            return torch.cuda.current_device()
+        elif current_platform.is_tpu():
+            return get_tensor_model_parallel_rank()
+        elif current_platform.is_xpu():
+            try:
+                import intel_extension_for_pytorch as ipex
+                return ipex.xpu.current_device()
+            except ImportError:
+                return get_tensor_model_parallel_rank()
+        else:
+            # For CPU and other platforms
+            return get_tensor_model_parallel_rank()
+
     def __del__(self):
         """Cleanup background threads on destruction."""
         self._handshake_initiation_executor.shutdown(wait=False)
@@ -567,8 +588,7 @@ def __del__(self):
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event, base_port: int,
-                                 tp_rank: int):
+                                 ready_event: threading.Event, base_port: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach via HTTP endpoint soon.
@@ -581,7 +601,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
 
         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        path = make_zmq_path("tcp", host, base_port + tp_rank)
+        path = make_zmq_path("tcp", host, base_port)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
@@ -607,13 +627,14 @@ def _nixl_handshake(
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.
-
         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
         tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
         p_remote_tp_rank = self.tp_rank // tp_ratio
-        p_remote_pp_rank = self.pp_rank  # don't support homogeneous PP
-        path = make_zmq_path("tcp", host, port + remote_tp_size * p_remote_pp_rank + p_remote_tp_rank)
+        p_remote_pp_rank = self.pp_rank  # don't support heterogeneous PP
+        path = make_zmq_path(
+            "tcp", host,
+            port + remote_tp_size * p_remote_pp_rank + p_remote_tp_rank)
         logger.debug("Querying metadata on path: %s at remote tp rank %s, remote pp rank %s", path,
                      p_remote_tp_rank, p_remote_pp_rank)
 
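A worked example of the handshake target port under the new scheme. The values are illustrative only: assume a decode worker with TP=4, PP handled homogeneously, pulling from a prefill deployment with TP=2 that advertised base port 6000.

# Illustrative values only (not from the commit).
remote_tp_size, port = 2, 6000   # prefill side: TP=2, advertised base port
local_tp_size = 4                # decode side: TP=4
local_tp_rank, local_pp_rank = 3, 0

tp_ratio = local_tp_size // remote_tp_size      # 2
p_remote_tp_rank = local_tp_rank // tp_ratio    # 1
p_remote_pp_rank = local_pp_rank                # homogeneous PP assumed
target_port = port + remote_tp_size * p_remote_pp_rank + p_remote_tp_rank
print(target_port)  # 6001 -> the listener of prefill pp_rank 0, tp_rank 1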
@@ -634,8 +655,10 @@ def _nixl_handshake(
                 f"received {metadata.engine_id}.")
 
         # Register Remote agent.
-        remote_agent_name = self.add_remote_agent(metadata, p_remote_tp_rank,
-                                                  remote_tp_size, p_remote_pp_rank,
+        remote_agent_name = self.add_remote_agent(metadata,
+                                                  p_remote_tp_rank,
+                                                  remote_tp_size,
+                                                  p_remote_pp_rank,
                                                   remote_pp_size)
         setup_agent_time = time.perf_counter()
         logger.debug("NIXL handshake: add agent took: %s",
@@ -677,11 +700,11 @@ def _background_nixl_handshake(self, req_id: str,
         fut = self._handshake_initiation_executor.submit(
             self._nixl_handshake, meta.remote_host, meta.remote_port,
             meta.tp_size, meta.pp_size, remote_engine_id)
-        self._handshake_futures[remote_engine_id] = {self.pp_rank: fut}
+        self._handshake_futures[remote_engine_id] = fut
 
         def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
             with self._handshake_lock:
-                del self._handshake_futures[eid][self.pp_rank]
+                del self._handshake_futures[eid]
                 try:
                     self._remote_agents[eid][self.pp_rank] = f.result()
                 except Exception:
@@ -834,7 +857,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
+            args=(metadata, ready_event, self.side_channel_port),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
@@ -895,8 +918,10 @@ def add_remote_agent(self,
         """  # noqa: E501
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
-        if remote_tp_rank in self._remote_agents.get(engine_id, {}).get(remote_pp_rank, {}):
-            return self._remote_agents[engine_id][remote_pp_rank][remote_tp_rank]
+        if remote_tp_rank in self._remote_agents.get(engine_id, {}).get(
+                remote_pp_rank, {}):
+            return self._remote_agents[engine_id][remote_pp_rank][
+                remote_tp_rank]
 
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
@@ -969,12 +994,13 @@ def add_remote_agent(self,
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                # blocks_data.append((addr, self.block_len, remote_tp_rank))
-                blocks_data.append((addr, self.block_len, remote_tp_rank + remote_pp_rank * remote_tp_size))
+                blocks_data.append((addr, self.block_len,
+                                    remote_tp_rank + remote_pp_rank * remote_tp_size))
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
-            "tp local rank %s, device id %s", len(blocks_data), engine_id, remote_tp_rank,
-            self.tp_rank, remote_tp_rank + remote_pp_rank * remote_tp_size)
+            "tp local rank %s, device id %s", len(blocks_data), engine_id,
+            remote_tp_rank, self.tp_rank,
+            remote_tp_rank + remote_pp_rank * remote_tp_size)
 
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
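A small worked example (values are illustrative only) of the remote device-id index used in blocks_data above, remote_tp_rank + remote_pp_rank * remote_tp_size:

# Illustrative values only.
remote_tp_size = 2
for remote_pp_rank in range(2):
    for remote_tp_rank in range(remote_tp_size):
        device_id = remote_tp_rank + remote_pp_rank * remote_tp_size
        print(remote_pp_rank, remote_tp_rank, device_id)

# (pp 0, tp 0) -> 0, (pp 0, tp 1) -> 1, (pp 1, tp 0) -> 2, (pp 1, tp 1) -> 3:
# each remote PP stage's TP workers map to a distinct, contiguous index.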
@@ -1093,10 +1119,8 @@ def _pop_done_transfers(
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
                     self.nixl_wrapper.release_xfer_handle(handle)
-                    logger.info(f"============transfer req_id {req_id} done")
                 elif xfer_state == "PROC":
                     in_progress = True
-                    logger.info(f"============transfer req_id {req_id} processing")
                     continue
                 else:
                     raise RuntimeError("Transfer failed with state %s",
@@ -1173,8 +1197,9 @@ def _read_blocks(self, local_block_ids: list[int],
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
             remote_rank = self.tp_rank // tp_ratio
-            remote_pp_rank = self.pp_rank  # don't consider heterogeneous PP now
-            agent_name = self._remote_agents[dst_engine_id][remote_pp_rank][remote_rank]
+            remote_pp_rank = self.pp_rank  # don't consider heterogeneous PP now
+            agent_name = self._remote_agents[dst_engine_id][remote_pp_rank][
+                remote_rank]
             self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
             return
 
vllm/distributed/parallel_state.py

Lines changed: 13 additions & 0 deletions
@@ -306,6 +306,11 @@ def next_rank(self):
         world_size = self.world_size
         return self.ranks[(rank_in_group + 1) % world_size]
 
+    @property
+    def current_rank(self):
+        """Return the current rank of the process"""
+        return self.rank
+
     @property
     def prev_rank(self):
         """Return the global rank of the process that precedes the caller"""
@@ -1231,6 +1236,14 @@ def get_pipeline_model_parallel_rank():
     return get_pp_group().rank_in_group
 
 
+def get_rank():
+    """Return the rank of this process within the DP group."""
+    global _WORLD
+    if _WORLD is not None:
+        return _WORLD.current_rank
+    return 0
+
+
 def get_node_count() -> int:
     """Return the total number of nodes in the distributed environment. """
     assert _NODE_COUNT is not None, (

vllm/executor/ray_utils.py

Lines changed: 9 additions & 0 deletions
@@ -143,6 +143,15 @@ def execute_model_ray(
             assert not output or not output.req_ids
             output = scheduler_output, None
         return output
+
+    def pull_kvcache_ray(
+            self,
+            scheduler_output: "SchedulerOutput") -> "ModelRunnerOutput":
+        assert self.worker is not None, "Worker is not initialized"
+
+        output = self.worker.model_runner.pull_kvcache(
+            scheduler_output)
+        return output
 
     def override_env_vars(self, vars: Dict[str, str]):
         os.environ.update(vars)
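A hypothetical caller-side sketch of how a Ray-based executor might drive the new RPC. The `workers` list and the `pull_kvcache_on_all_workers` helper are assumptions for illustration, not code from this commit; only the actor-handle `.remote()` / `ray.get()` calls are standard Ray usage.

import ray

def pull_kvcache_on_all_workers(workers, scheduler_output):
    # `workers` is assumed to be a list of Ray actor handles wrapping the
    # worker class above; each call forwards to
    # worker.model_runner.pull_kvcache(scheduler_output) on that actor.
    refs = [w.pull_kvcache_ray.remote(scheduler_output) for w in workers]
    return ray.get(refs)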
