 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
-    get_pipeline_model_parallel_rank, get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size, get_tp_group)
+    get_rank, get_pipeline_model_parallel_rank,
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
+    get_tp_group)
 from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
@@ -113,8 +114,7 @@ def add_new_req(
             remote_port=kv_transfer_params["remote_port"],
             # P workers don't need to receive tp_size from proxy here.
             tp_size=kv_transfer_params.get("tp_size", 1),
-            pp_size=kv_transfer_params.get("pp_size", 1)
-        )
+            pp_size=kv_transfer_params.get("pp_size", 1))
         if save_to_host:
             self.reqs_to_save[request_id] = _req
         if load_remote_cache:
@@ -223,9 +223,9 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,

     def wait_for_save(self):
         assert self.connector_worker is not None
-        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
         if self.connector_worker.use_host_buffer and \
             self.connector_worker.copy_blocks:
+            assert isinstance(self._connector_metadata, NixlConnectorMetadata)
             self.connector_worker.save_kv_to_host(self._connector_metadata)


@@ -439,24 +439,26 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {pprank0: {rank0: agent_name0, rank1: agent_name1},
         #                      pprank1: {rank0: agent_name2, rank1: agent_name3, ...}, ...}.
-        self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = defaultdict(dict)
+        self._remote_agents: dict[EngineId,
+                                  dict[int, dict[int,
+                                                 str]]] = defaultdict(dict)

         # NIXL handshake port.
         # NOTE(rob): Within a DP group, each DP rank gets its own
         # base port (which is sent in the KVTransferParams).
-        # Each TP rank listens/queries on the base_port +
-        # pp_rank * tp_size + tp_rank.
-        self.pp_rank = get_pipeline_model_parallel_rank()
+        # Each TP rank listens/queries on the base_port + global_rank.
+        self.rank_in_dp_group = get_rank()
         self.side_channel_port: int = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
             vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size *
             vllm_config.parallel_config.pipeline_parallel_size +
-            self.pp_rank * vllm_config.parallel_config.tensor_parallel_size)
+            self.rank_in_dp_group)

         # Metadata.
         self.engine_id: EngineId = engine_id
         self.tp_rank = get_tensor_model_parallel_rank()
+        self.pp_rank = get_pipeline_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
         self.tp_group = get_tp_group()
         self.num_blocks = 0
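
For illustration, a minimal sketch of the side-channel port layout this computes; the base port 5600 and the TP=2, PP=2, DP=2 sizes are assumed values standing in for envs.VLLM_NIXL_SIDE_CHANNEL_PORT and the parallel_config fields read above:

    base_port = 5600                     # assumed VLLM_NIXL_SIDE_CHANNEL_PORT
    tp_size, pp_size, dp_size = 2, 2, 2  # assumed parallel_config values
    for dp_rank in range(dp_size):
        for rank_in_dp_group in range(tp_size * pp_size):
            # Same arithmetic as side_channel_port above: each DP rank owns a
            # contiguous block of tp_size * pp_size ports.
            port = base_port + dp_rank * tp_size * pp_size + rank_in_dp_group
            # dp_rank 0 -> ports 5600..5603, dp_rank 1 -> ports 5604..5607
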
@@ -524,7 +526,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             max_workers=1,
             thread_name_prefix="vllm-nixl-handshake-initiator")
         self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]()
-        self._handshake_futures: dict[EngineId, dict[int, Future[dict[int, str]]]] = {}
+        self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {}
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()

@@ -554,11 +556,30 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
-        self.device_id = torch.cuda.current_device()
+        self.device_id = self._get_current_device_id()
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)

+    def _get_current_device_id(self) -> int:
+        """Get the current device ID in a platform-agnostic way."""
+        from vllm.platforms import current_platform
+
+        if current_platform.is_cuda_alike():
+            return torch.cuda.current_device()
+        elif current_platform.is_tpu():
+            return get_tensor_model_parallel_rank()
+        elif current_platform.is_xpu():
+            try:
+                import intel_extension_for_pytorch as ipex
+                return ipex.xpu.current_device()
+            except ImportError:
+                return get_tensor_model_parallel_rank()
+        else:
+            # For CPU and other platforms
+            return get_tensor_model_parallel_rank()
+
+
     def __del__(self):
         """Cleanup background threads on destruction."""
         self._handshake_initiation_executor.shutdown(wait=False)
@@ -567,8 +588,7 @@ def __del__(self):

     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event, base_port: int,
-                                 tp_rank: int):
+                                 ready_event: threading.Event, base_port: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach via HTTP endpoint soon.
@@ -581,7 +601,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,

         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        path = make_zmq_path("tcp", host, base_port + tp_rank)
+        path = make_zmq_path("tcp", host, base_port)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
@@ -607,13 +627,14 @@ def _nixl_handshake(
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-
         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
         tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
         p_remote_tp_rank = self.tp_rank // tp_ratio
-        p_remote_pp_rank = self.pp_rank  # don't support homogeneous PP
-        path = make_zmq_path("tcp", host, port + remote_tp_size * p_remote_pp_rank + p_remote_tp_rank)
+        p_remote_pp_rank = self.pp_rank  # don't support heterogeneous PP
+        path = make_zmq_path(
+            "tcp", host,
+            port + remote_tp_size * p_remote_pp_rank + p_remote_tp_rank)
         logger.debug("Querying metadata on path: %s at remote tp rank %s, remote pp rank %s", path,
                      p_remote_tp_rank, p_remote_pp_rank)

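As an illustrative sketch of the port expression above (the sizes and base port 6000 are assumptions, not values from the patch): under heterogeneous TP each decode rank pulls from the prefill rank it maps to, and under homogeneous PP it uses its own pp_rank on the remote side.

    local_tp, remote_tp_size, base_port = 4, 2, 6000  # assumed sizes and port
    tp_ratio = local_tp // remote_tp_size             # 2: two D ranks share one P rank
    for pp_rank in range(2):                          # homogeneous PP: same rank on both sides
        for tp_rank in range(local_tp):
            p_remote_tp_rank = tp_rank // tp_ratio
            remote_port = base_port + remote_tp_size * pp_rank + p_remote_tp_rank
            # pp 0 queries ports 6000/6001, pp 1 queries ports 6002/6003
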
@@ -634,8 +655,10 @@ def _nixl_handshake(
                                   f"received {metadata.engine_id}.")

             # Register Remote agent.
-            remote_agent_name = self.add_remote_agent(metadata, p_remote_tp_rank,
-                                                      remote_tp_size, p_remote_pp_rank,
+            remote_agent_name = self.add_remote_agent(metadata,
+                                                      p_remote_tp_rank,
+                                                      remote_tp_size,
+                                                      p_remote_pp_rank,
                                                       remote_pp_size)
             setup_agent_time = time.perf_counter()
             logger.debug("NIXL handshake: add agent took: %s",
@@ -677,11 +700,11 @@ def _background_nixl_handshake(self, req_id: str,
             fut = self._handshake_initiation_executor.submit(
                 self._nixl_handshake, meta.remote_host, meta.remote_port,
                 meta.tp_size, meta.pp_size, remote_engine_id)
-            self._handshake_futures[remote_engine_id] = {self.pp_rank: fut}
+            self._handshake_futures[remote_engine_id] = fut

             def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
                 with self._handshake_lock:
-                    del self._handshake_futures[eid][self.pp_rank]
+                    del self._handshake_futures[eid]
                     try:
                         self._remote_agents[eid][self.pp_rank] = f.result()
                     except Exception:
@@ -834,7 +857,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
+            args=(metadata, ready_event, self.side_channel_port),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
@@ -895,8 +918,10 @@ def add_remote_agent(self,
         """  # noqa: E501
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
-        if remote_tp_rank in self._remote_agents.get(engine_id, {}).get(remote_pp_rank, {}):
-            return self._remote_agents[engine_id][remote_pp_rank][remote_tp_rank]
+        if remote_tp_rank in self._remote_agents.get(engine_id, {}).get(
+                remote_pp_rank, {}):
+            return self._remote_agents[engine_id][remote_pp_rank][
+                remote_tp_rank]

         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
@@ -969,12 +994,13 @@ def add_remote_agent(self,
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                # blocks_data.append((addr, self.block_len, remote_tp_rank))
-                blocks_data.append((addr, self.block_len, remote_tp_rank + remote_pp_rank * remote_tp_size))
+                blocks_data.append((addr, self.block_len,
+                                    remote_tp_rank + remote_pp_rank * remote_tp_size))
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
-            "tp local rank %s, device id %s", len(blocks_data), engine_id, remote_tp_rank,
-            self.tp_rank, remote_tp_rank + remote_pp_rank * remote_tp_size)
+            "tp local rank %s, device id %s", len(blocks_data), engine_id,
+            remote_tp_rank, self.tp_rank,
+            remote_tp_rank + remote_pp_rank * remote_tp_size)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
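
For illustration only: the device-id slot of each remote block descriptor now carries a flattened (pp_rank, tp_rank) index as computed above; with assumed remote sizes TP=4 and PP=2, the mapping is:

    remote_tp_size, remote_pp_size = 4, 2  # assumed remote parallel sizes
    for remote_pp_rank in range(remote_pp_size):
        for remote_tp_rank in range(remote_tp_size):
            device_id = remote_tp_rank + remote_pp_rank * remote_tp_size
            # pp 0 -> device ids 0..3, pp 1 -> device ids 4..7
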
@@ -1093,10 +1119,8 @@ def _pop_done_transfers(
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
                     self.nixl_wrapper.release_xfer_handle(handle)
-                    logger.info(f"============transfer req_id {req_id} done")
                 elif xfer_state == "PROC":
                     in_progress = True
-                    logger.info(f"============transfer req_id {req_id} processing")
                     continue
                 else:
                     raise RuntimeError("Transfer failed with state %s",
@@ -1173,8 +1197,9 @@ def _read_blocks(self, local_block_ids: list[int],
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
             remote_rank = self.tp_rank // tp_ratio
-            remote_pp_rank = self.pp_rank # don't consider heterogeneous PP now
-            agent_name = self._remote_agents[dst_engine_id][remote_pp_rank][remote_rank]
+            remote_pp_rank = self.pp_rank  # don't consider heterogeneous PP now
+            agent_name = self._remote_agents[dst_engine_id][remote_pp_rank][
+                remote_rank]
             self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
             return
