From ed7fe6b6b53806803cd98c85e2e2b5484bd2e549 Mon Sep 17 00:00:00 2001 From: Buke Ao Date: Mon, 15 Sep 2025 10:59:52 -0700 Subject: [PATCH 1/3] first commit for heterogeneous PD: accommodate block_size gap --- .../kv_connector/v1/nixl_connector.py | 69 +++++++++++++------ 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2bcee6ba23d2..f7c2d6299aca 100755 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -216,7 +216,7 @@ def start_load_kv(self, forward_context: "ForwardContext", assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) - logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}") + #logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}") def wait_for_layer_load(self, layer_name: str) -> None: """NixlConnector does not do layerwise saving.""" @@ -449,7 +449,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Config. self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size - + self.block_factor = 8 # A100.block_size/G2.block_size + self.block_shape = None # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. @@ -786,6 +787,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_shape[0] = block_shape[0] // self.num_blocks block_shape = torch.Size(block_shape) block_size, n_kv_heads, head_dim = block_shape[-3:] + self.block_shape = [block_size, n_kv_heads, head_dim] # head size in bytes. self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim else: @@ -802,7 +804,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, self.use_host_buffer, self.num_blocks, block_shape, first_kv_cache[0].shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks + self.dst_num_blocks[self.engine_id] = self.num_blocks * self.block_factor self.device_kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] @@ -870,15 +872,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # could create fewer, but then _get_block_descs_ids needs to # select agent_meta.num_blocks instead of self.num_blocks for # local descr, and that makes handling regular flow less clean. - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len + for block_id in range(self.num_blocks * self.block_factor): + block_offset = block_id * self.block_len // self.block_factor addr = base_addr + block_offset # (addr, len, device id) # TODO: does device_id matter to DRAM? - blocks_data.append((addr, self.block_len, self.tp_rank)) + blocks_data.append((addr, self.block_len//self.block_factor, self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) - + print(f'buke: {blocks_data[0:10]=}') descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. @@ -955,7 +957,7 @@ def add_remote_agent(self, assert self._tp_size[engine_id] == remote_tp_size # We may eventually enable this after asserting equality in cache # layout and close outputs. - assert nixl_agent_meta.attn_backend_name == self.backend_name + #assert nixl_agent_meta.attn_backend_name == self.backend_name remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata) @@ -984,14 +986,14 @@ def add_remote_agent(self, # Account for joint KV in FlashInfer. remote_block_size //= 2 - assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." - ) + #assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( + # "Remote P worker KV layer cache must be of shape [2, N, " + # "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + #) - assert self.block_size == remote_block_size, ( - "Remote P worker with different block size is not supported " - f"{self.block_size=} {remote_block_size=}") + #assert self.block_size == remote_block_size, ( + # "Remote P worker with different block size is not supported " + # f"{self.block_size=} {remote_block_size=}") # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -1018,7 +1020,7 @@ def add_remote_agent(self, # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_tp_rank)) + blocks_data.append((addr, nixl_agent_meta.block_len, remote_tp_rank)) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " "local rank %s", len(blocks_data), engine_id, remote_tp_rank, @@ -1081,6 +1083,24 @@ def get_finished(self) -> tuple[set[str], set[str]]: "Rank %s, get_finished: %s requests done sending " "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) + #import remote_pdb; remote_pdb.set_trace() + t1 = time.perf_counter() + remote_block_size = self.block_size // self.block_factor + block_size, n_kv_heads, head_dim = self.block_shape + for req_id in done_recving: + #print(req_id, self._recving_metadata) + meta = self._recving_metadata.pop(req_id) + for k, v in self.device_kv_caches.values(): + local_block_ids = meta.local_block_ids + #print(f'buke {local_block_ids=}|{k.shape=}') + for block_idx in local_block_ids: + #import remote_pdb; remote_pdb.set_trace() + k[block_idx*self.block_size: (1+block_idx)*self.block_size] = k[block_idx*self.block_size: (1+block_idx)*self.block_size].reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(self.block_size,n_kv_heads,head_dim) + v[block_idx*self.block_size: (1+block_idx)*self.block_size] = v[block_idx*self.block_size: (1+block_idx)*self.block_size].reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(self.block_size,n_kv_heads,head_dim) + #import remote_pdb; remote_pdb.set_trace() + t2 = time.perf_counter() + tt = t2-t1 + print(f'buke permute time:{tt}, {req_id=}|{self._recving_metadata=}') if self.use_host_buffer: for req_id in done_recving: s2 = time.perf_counter() @@ -1178,7 +1198,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - if self.use_host_buffer: + is_hetero = True + if is_hetero or self.use_host_buffer: self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. @@ -1245,8 +1266,8 @@ def _read_blocks(self, local_block_ids: list[int], # Partial prefix cache hit: just read uncomputed blocks. num_remote_blocks = len(remote_block_ids) assert num_local_blocks <= num_remote_blocks - if num_local_blocks < num_remote_blocks: - remote_block_ids = remote_block_ids[-num_local_blocks:] + #if num_local_blocks < num_remote_blocks: + # remote_block_ids = remote_block_ids[-num_local_blocks:] # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle @@ -1259,12 +1280,19 @@ def _read_blocks(self, local_block_ids: list[int], # Get descs ids. local_block_descs_ids: list[int] = [] remote_block_descs_ids: list[int] = [] + local_sub_block_ids: list[int] = [] + + print('buke: ', remote_block_ids) + remote_block_ids = remote_block_ids[:(len(remote_block_ids)//self.block_factor)*self.block_factor] + for index in range(0, len(remote_block_ids)): + local_sub_block_ids.append(local_block_ids[index//self.block_factor]*self.block_factor + index) + print(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=}') if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( dst_engine_id, remote_block_ids) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids) + self.engine_id, local_sub_block_ids) else: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) @@ -1303,6 +1331,7 @@ def _read_blocks(self, local_block_ids: list[int], ) # Begin async xfer. + print('buke ->>>> transfer start >>>---') self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). From 40a53a513ad21301cbfda145645fa483e1e573e2 Mon Sep 17 00:00:00 2001 From: Buke Ao Date: Thu, 18 Sep 2025 10:22:05 -0700 Subject: [PATCH 2/3] fix indexing bug and enable the 1P1D mode, where P uses TP=1 and D uses TP=2 --- .../nixl_integration/toy_proxy_server.py | 12 ++++++--- .../kv_connector/v1/nixl_connector.py | 27 +++++++++---------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 163c5547a914..5b3e60b010f8 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -212,7 +212,6 @@ async def _handle_completions(api: str, request: Request): p = send_request_to_service(prefill_client_info, api, req_data, request_id) s2 = time.perf_counter() - print(f'libin proxy send to prefill {s2-s1}') sys.stdout.flush() response = await p s3 = time.perf_counter() @@ -220,8 +219,15 @@ async def _handle_completions(api: str, request: Request): response_json = response.json() kv_transfer_params = response_json.get('kv_transfer_params', {}) if kv_transfer_params: + remote_block_len = len(kv_transfer_params['remote_block_ids']) + logger.debug('buke: cut:', type(kv_transfer_params), kv_transfer_params['remote_block_ids'],kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8]) + + kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8] + if remote_block_len % 8 == 0: + kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8-1] + logger.info('buke hit corner case multiples of 8:', remote_block_len) req_data["kv_transfer_params"] = kv_transfer_params - + #print(req_data) # Get the next decode client in round-robin fashion decode_client_info = get_next_client(request.app, 'decode') @@ -238,7 +244,6 @@ async def generate_stream(): if is_first is False: s4 = time.perf_counter() - print(f'libin debug proxy receive decode 1 total:{s4-s1}| prefill:{s3-s1}| in-between:{s6-s3}|decode:{s4-s6}| {s6=} {s4=}') sys.stdout.flush() is_first = True yield chunk @@ -246,7 +251,6 @@ async def generate_stream(): re = StreamingResponse(generate_stream(), media_type="application/json") s5 = time.perf_counter() - #print(f'libin debug proxy receive decode2 {s5-s1} {s5-s6}') #sys.stdout.flush() return re diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f7c2d6299aca..852a18099c2f 100755 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -623,7 +623,6 @@ def _nixl_handshake( path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) - # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: sock.send(GET_META_MSG) @@ -817,6 +816,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). + for cache_or_caches in xfer_buffers.values(): # Normalize to always be a list of caches cache_list = [cache_or_caches] if use_mla \ @@ -873,14 +873,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # select agent_meta.num_blocks instead of self.num_blocks for # local descr, and that makes handling regular flow less clean. for block_id in range(self.num_blocks * self.block_factor): - block_offset = block_id * self.block_len // self.block_factor + block_offset = block_id * self.block_len // (self.block_factor) addr = base_addr + block_offset # (addr, len, device id) # TODO: does device_id matter to DRAM? - blocks_data.append((addr, self.block_len//self.block_factor, self.tp_rank)) + blocks_data.append((addr, self.block_len//(self.block_factor), self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) - print(f'buke: {blocks_data[0:10]=}') + #print(f'buke: {blocks_data[0:10]=}') descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. @@ -1009,7 +1009,7 @@ def add_remote_agent(self, # Only register the remote's descriptors if current rank pulls from it. self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.block_len \ + rank_offset = self.tp_rank % tp_ratio * nixl_agent_meta.block_len // tp_ratio \ if not (self.use_mla or is_kv_replicated) else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: @@ -1020,15 +1020,16 @@ def add_remote_agent(self, # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, nixl_agent_meta.block_len, remote_tp_rank)) + blocks_data.append((addr, nixl_agent_meta.block_len//tp_ratio, remote_tp_rank)) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " "local rank %s", len(blocks_data), engine_id, remote_tp_rank, self.tp_rank) - + logger.debug(f'buke {self.slot_size_bytes=}|{tp_ratio=}|{self.block_len=}|{nixl_agent_meta.block_len=}|{self.tp_rank=}|{self._use_flashinfer=}') # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + #print('buke register remote:', len(blocks_data), blocks_data[:10],blocks_data[-1],self.nixl_memory_type) self.dst_xfer_side_handles[ engine_id] = self.nixl_wrapper.prep_xfer_dlist( remote_agent_name, descs) @@ -1100,7 +1101,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: #import remote_pdb; remote_pdb.set_trace() t2 = time.perf_counter() tt = t2-t1 - print(f'buke permute time:{tt}, {req_id=}|{self._recving_metadata=}') + logger.debug(f'buke permute time:{tt}, {req_id=}|{self._recving_metadata=}') if self.use_host_buffer: for req_id in done_recving: s2 = time.perf_counter() @@ -1282,11 +1283,11 @@ def _read_blocks(self, local_block_ids: list[int], remote_block_descs_ids: list[int] = [] local_sub_block_ids: list[int] = [] - print('buke: ', remote_block_ids) + #print('buke: ', remote_block_ids) remote_block_ids = remote_block_ids[:(len(remote_block_ids)//self.block_factor)*self.block_factor] - for index in range(0, len(remote_block_ids)): - local_sub_block_ids.append(local_block_ids[index//self.block_factor]*self.block_factor + index) - print(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=}') + for index,remote_block_id in enumerate(remote_block_ids): + local_sub_block_ids.append(local_block_ids[index//self.block_factor]*self.block_factor + index%self.block_factor) + logger.debug(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=}') if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( @@ -1331,7 +1332,6 @@ def _read_blocks(self, local_block_ids: list[int], ) # Begin async xfer. - print('buke ->>>> transfer start >>>---') self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). @@ -1379,7 +1379,6 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: if socket_type not in (zmq.ROUTER, zmq.REQ): raise ValueError(f"Unexpected socket type: {socket_type}") - ctx: Optional[zmq.Context] = None try: ctx = zmq.Context() # type: ignore[attr-defined] From 1744167aaf0b7bce8f470d6c3e5018f76a4a4951 Mon Sep 17 00:00:00 2001 From: Buke Ao Date: Thu, 25 Sep 2025 16:06:18 -0700 Subject: [PATCH 3/3] add backend_name check --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 852a18099c2f..dc8190b2f20b 100755 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -957,7 +957,7 @@ def add_remote_agent(self, assert self._tp_size[engine_id] == remote_tp_size # We may eventually enable this after asserting equality in cache # layout and close outputs. - #assert nixl_agent_meta.attn_backend_name == self.backend_name + assert nixl_agent_meta.attn_backend_name == "FLASH_ATTN_VLLM_V1" or nixl_agent_meta.attn_backend_name == "HPU_ATTN_V1" remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata)