hetero pd2 #1980
base: libint/debug_ttft
@@ -205,18 +205,18 @@
                      finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         """Get the finished recving and sending requests."""
         s1 = time.perf_counter()
         assert self.connector_worker is not None
         re = self.connector_worker.get_finished()
         #logger.info(f'libin debug get_finished {os.getenv('RANK')}, takes {time.perf_counter() - s1}')
         return re

     def start_load_kv(self, forward_context: "ForwardContext",
                       **kwargs) -> None:
         s1 = time.perf_counter()
         assert self.connector_worker is not None
         assert isinstance(self._connector_metadata, NixlConnectorMetadata)
         self.connector_worker.start_load_kv(self._connector_metadata)
-        logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}")
+        #logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}")

     def wait_for_layer_load(self, layer_name: str) -> None:
         """NixlConnector does not do layerwise saving."""
@@ -449,7 +449,8 @@
         # Config.
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
+        self.block_factor = 8  # A100.block_size/G2.block_size
+        self.block_shape = None
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@@ -622,7 +623,6 @@
             path = make_zmq_path("tcp", host, port + p_remote_rank)
             logger.debug("Querying metadata on path: %s at remote rank %s", path,
                          p_remote_rank)

             # Send query for the request.
             with zmq_ctx(zmq.REQ, path) as sock:
                 sock.send(GET_META_MSG)
@@ -786,6 +786,7 @@
             block_shape[0] = block_shape[0] // self.num_blocks
             block_shape = torch.Size(block_shape)
             block_size, n_kv_heads, head_dim = block_shape[-3:]
+            self.block_shape = [block_size, n_kv_heads, head_dim]
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
         else:
@@ -802,7 +803,7 @@
             "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
             self.use_host_buffer, self.num_blocks, block_shape,
             first_kv_cache[0].shape)
-        self.dst_num_blocks[self.engine_id] = self.num_blocks
+        self.dst_num_blocks[self.engine_id] = self.num_blocks * self.block_factor
         self.device_kv_caches = kv_caches
         kv_caches_base_addr = []
         caches_data = []
@@ -815,6 +816,7 @@
         # (roughly 8KB vs 5KB).
         # Conversely for FlashInfer, K and V are transferred in the same tensor
         # to better exploit the memory layout (ie num_blocks is the first dim).

         for cache_or_caches in xfer_buffers.values():
             # Normalize to always be a list of caches
             cache_list = [cache_or_caches] if use_mla \
@@ -870,15 +872,15 @@
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
-            for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+            for block_id in range(self.num_blocks * self.block_factor):
+                block_offset = block_id * self.block_len // (self.block_factor)
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 # TODO: does device_id matter to DRAM?
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, self.block_len//(self.block_factor), self.tp_rank))
         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)

         #print(f'buke: {blocks_data[0:10]=}')
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                                  self.nixl_memory_type)
         # NIXL_INIT_AGENT to be used for preparations of local descs.
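To make the descriptor arithmetic above concrete, here is a small standalone sketch of how the local side now registers `block_factor` sub-blocks per KV block. All numeric values are illustrative assumptions, not sizes taken from the PR.

```python
# Illustrative sketch of the sub-block descriptor layout.
base_addr = 0x1000
num_blocks = 2        # local KV blocks backing the cache (assumed)
block_len = 4096      # bytes per local block (assumed)
block_factor = 8      # remote blocks per local block
sub_len = block_len // block_factor
tp_rank = 0

blocks_data = []
for block_id in range(num_blocks * block_factor):
    # Same arithmetic as the diff: sub-blocks advance by block_len // block_factor.
    addr = base_addr + block_id * block_len // block_factor
    blocks_data.append((addr, sub_len, tp_rank))

# 16 descriptors of 512 bytes cover exactly the 8192 bytes of the two blocks.
assert len(blocks_data) == 16
assert blocks_data[-1][0] + sub_len == base_addr + num_blocks * block_len
```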
@@ -955,7 +957,7 @@
         assert self._tp_size[engine_id] == remote_tp_size
         # We may eventually enable this after asserting equality in cache
         # layout and close outputs.
-        assert nixl_agent_meta.attn_backend_name == self.backend_name
+        assert nixl_agent_meta.attn_backend_name == "FLASH_ATTN_VLLM_V1" or nixl_agent_meta.attn_backend_name == "HPU_ATTN_V1"

         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata)
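A review comment further down asks for an explicit check that the remote side runs a GPU attention backend. A minimal sketch of such a check, assuming the handshake metadata keeps exposing `attn_backend_name`; the helper and the allow-list below are hypothetical, not part of the PR:

```python
# Hypothetical helper; GPU_ATTN_BACKENDS is an assumption, not a vLLM constant.
GPU_ATTN_BACKENDS = {"FLASH_ATTN_VLLM_V1"}

def check_remote_is_gpu_attention(attn_backend_name: str) -> None:
    """Reject handshakes from remotes that do not run a GPU attention backend."""
    if attn_backend_name not in GPU_ATTN_BACKENDS:
        raise ValueError(
            "Heterogeneous PD expects a GPU attention backend on the remote "
            f"prefill worker, got {attn_backend_name!r}")

check_remote_is_gpu_attention("FLASH_ATTN_VLLM_V1")  # passes
```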
@@ -984,14 +986,14 @@
             # Account for joint KV in FlashInfer.
             remote_block_size //= 2

-        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
-            "Remote P worker KV layer cache must be of shape [2, N, "
-            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
-        )
+        #assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+        #    "Remote P worker KV layer cache must be of shape [2, N, "
+        #    "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        #)

-        assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+        #assert self.block_size == remote_block_size, (
+        #    "Remote P worker with different block size is not supported "
+        #    f"{self.block_size=} {remote_block_size=}")

         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
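Rather than commenting the layout asserts out entirely, a hetero-aware variant could keep equivalent checks. A sketch under the assumption that `block_factor == local_block_size // remote_block_size` and that the transfer pairs local sub-blocks of `block_len // block_factor` bytes with remote per-rank slices of `block_len // tp_ratio` bytes, as the descriptor code in this diff does; the helper is not part of the PR:

```python
# Hypothetical hetero-aware replacement for the disabled asserts above.
def check_remote_kv_layout(local_block_len: int, remote_block_len: int,
                           local_block_size: int, remote_block_size: int,
                           tp_ratio: int, block_factor: int) -> None:
    # Block sizes must differ by exactly the configured factor.
    assert local_block_size == remote_block_size * block_factor, (
        f"{local_block_size=} is not {block_factor}x {remote_block_size=}")
    # A local sub-block must carry the same number of bytes as the per-rank
    # slice of a remote block, otherwise the NIXL descriptors cannot pair up.
    assert local_block_len * tp_ratio == remote_block_len * block_factor, (
        f"{local_block_len=} and {remote_block_len=} do not match the "
        f"expected hetero layout for {tp_ratio=}, {block_factor=}")
```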
@@ -1007,7 +1009,7 @@
         # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        rank_offset = self.tp_rank % tp_ratio * nixl_agent_meta.block_len // tp_ratio \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
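The new `rank_offset` expression is a generalization of the old one: in the homogeneous case, where `remote_block_len == local_block_len * tp_ratio`, the two are identical, as this small check with assumed example numbers illustrates.

```python
# Example numbers are assumptions; they only illustrate the equivalence.
tp_rank, tp_ratio, local_block_len = 3, 2, 4096
remote_block_len = local_block_len * tp_ratio   # homogeneous layout

old_offset = tp_rank % tp_ratio * local_block_len
new_offset = tp_rank % tp_ratio * remote_block_len // tp_ratio
assert old_offset == new_offset == 4096
```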
@@ -1018,15 +1020,16 @@
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, nixl_agent_meta.block_len//tp_ratio, remote_tp_rank))
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
             self.tp_rank)

+        logger.debug(f'buke {self.slot_size_bytes=}|{tp_ratio=}|{self.block_len=}|{nixl_agent_meta.block_len=}|{self.tp_rank=}|{self._use_flashinfer=}')
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                                  self.nixl_memory_type)
         #print('buke register remote:', len(blocks_data), blocks_data[:10],blocks_data[-1],self.nixl_memory_type)
         self.dst_xfer_side_handles[
             engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                 remote_agent_name, descs)
@@ -1081,6 +1084,24 @@
             "Rank %s, get_finished: %s requests done sending "
             "and %s requests done recving", self.tp_rank,
             len(done_sending), len(done_recving))
+        #import remote_pdb; remote_pdb.set_trace()
Review thread:
- Can you add a check that the remote side is using GPU attention?
- Yes, I can add this.
+        t1 = time.perf_counter()
+        remote_block_size = self.block_size // self.block_factor
+        block_size, n_kv_heads, head_dim = self.block_shape
+        for req_id in done_recving:
+            #print(req_id, self._recving_metadata)
+            meta = self._recving_metadata.pop(req_id)
+            for k, v in self.device_kv_caches.values():
+                local_block_ids = meta.local_block_ids
+                #print(f'buke {local_block_ids=}|{k.shape=}')
+                for block_idx in local_block_ids:
+                    #import remote_pdb; remote_pdb.set_trace()
+                    k[block_idx*self.block_size: (1+block_idx)*self.block_size] = k[block_idx*self.block_size: (1+block_idx)*self.block_size].reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(self.block_size,n_kv_heads,head_dim)
+                    v[block_idx*self.block_size: (1+block_idx)*self.block_size] = v[block_idx*self.block_size: (1+block_idx)*self.block_size].reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(self.block_size,n_kv_heads,head_dim)
+        #import remote_pdb; remote_pdb.set_trace()
+        t2 = time.perf_counter()
+        tt = t2-t1
+        logger.debug(f'buke permute time:{tt}, {req_id=}|{self._recving_metadata=}')
         if self.use_host_buffer:
             for req_id in done_recving:
                 s2 = time.perf_counter()
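The dense one-liners above re-lay-out each received local block in place. My reading of the diff is that the data arrives as `block_factor` remote sub-blocks stored head-major (`n_kv_heads, remote_block_size, head_dim`) and has to become token-major (`block_size, n_kv_heads, head_dim`) locally; the standalone sketch below reproduces the same reshape/permute on a dummy tensor with assumed shapes.

```python
import torch

# Illustrative shapes only; these are assumptions, not the PR's real sizes.
block_factor, n_kv_heads, remote_block_size, head_dim = 2, 4, 8, 16
block_size = block_factor * remote_block_size   # local (decode-side) block size

# One local block as received over NIXL, flattened to the local cache shape.
received = torch.randn(block_size, n_kv_heads, head_dim)

# Same transformation as the diff: view the block as its sub-blocks, swap the
# head and token axes inside each sub-block, then flatten back.
fixed = (received
         .reshape(block_factor, n_kv_heads, remote_block_size, head_dim)
         .permute(0, 2, 1, 3)
         .contiguous()
         .reshape(block_size, n_kv_heads, head_dim))

assert fixed.shape == received.shape == (block_size, n_kv_heads, head_dim)
```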
@@ -1178,7 +1199,8 @@
             "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
             remote_engine_id, len(meta.local_block_ids),
             len(meta.remote_block_ids))
-        if self.use_host_buffer:
+        is_hetero = True
Review thread (on is_hetero):
- What's the plan for setting this variable?
- The flag is needed to enable the _recving_metadata path.
- Are you going to add it as an env flag?
- Can we reuse the DECODE_TP_RATIO env variable?
- That seems fine.
- Check whether this can be moved to hpu_model_runner.
+        if is_hetero or self.use_host_buffer:
             self._recving_metadata[req_id] = meta
         if remote_engine_id not in self._remote_agents:
             # Initiate handshake with remote engine to exchange metadata.
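The thread above leaves open how `is_hetero` should be set instead of being hardcoded to `True`. One option mentioned there is reusing the `DECODE_TP_RATIO` environment variable; the sketch below assumes that simply having the variable set marks the hetero path, which is an interpretation, not something the PR defines.

```python
import os

# Assumed convention: the hetero path is active whenever DECODE_TP_RATIO is set.
is_hetero = os.getenv("DECODE_TP_RATIO") is not None
```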
@@ -1245,8 +1267,8 @@
             # Partial prefix cache hit: just read uncomputed blocks.
             num_remote_blocks = len(remote_block_ids)
             assert num_local_blocks <= num_remote_blocks
-            if num_local_blocks < num_remote_blocks:
-                remote_block_ids = remote_block_ids[-num_local_blocks:]
+            #if num_local_blocks < num_remote_blocks:
+            #    remote_block_ids = remote_block_ids[-num_local_blocks:]
Review comment: Add a check for the hetero case here.
             # Get side handles.
             local_xfer_side_handle = self.src_xfer_side_handle
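The review comment asks for a hetero check instead of deleting the partial-prefix-hit trimming outright. A possible shape for that, assuming each local block corresponds to `block_factor` remote blocks; the helper is hypothetical, not part of the PR:

```python
# Hypothetical guard: keep the original trimming on the homogeneous path and
# scale it by block_factor on the hetero path.
def trim_remote_blocks(remote_block_ids: list[int], num_local_blocks: int,
                       is_hetero: bool, block_factor: int) -> list[int]:
    keep = num_local_blocks * block_factor if is_hetero else num_local_blocks
    if keep < len(remote_block_ids):
        remote_block_ids = remote_block_ids[-keep:]
    return remote_block_ids
```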
@@ -1259,12 +1281,19 @@
         # Get descs ids.
         local_block_descs_ids: list[int] = []
         remote_block_descs_ids: list[int] = []
+        local_sub_block_ids: list[int] = []
+
+        #print('buke: ', remote_block_ids)
+        remote_block_ids = remote_block_ids[:(len(remote_block_ids)//self.block_factor)*self.block_factor]
+        for index,remote_block_id in enumerate(remote_block_ids):
+            local_sub_block_ids.append(local_block_ids[index//self.block_factor]*self.block_factor + index%self.block_factor)
+        logger.debug(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=}')
         if not self.block_window_per_layer:
             # Default case: assume global attention
             remote_block_descs_ids = self._get_block_descs_ids(
                 dst_engine_id, remote_block_ids)
             local_block_descs_ids = self._get_block_descs_ids(
-                self.engine_id, local_block_ids)
+                self.engine_id, local_sub_block_ids)
         else:
             # TODO(mgoin): remove this once we have hybrid memory allocator
             # Optimization for models with local attention (Llama 4)
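A worked example of the mapping built above, with small assumed values: remote block `i` lands in sub-block `i % block_factor` of local block `local_block_ids[i // block_factor]`.

```python
# Small illustrative values; block_factor and the ids are assumptions.
block_factor = 4
local_block_ids = [7, 2]              # local (decode-side) blocks for this request
remote_block_ids = list(range(9))     # remote (prefill-side) blocks

# Only whole local blocks can be filled, so drop the trailing remainder.
remote_block_ids = remote_block_ids[
    :(len(remote_block_ids) // block_factor) * block_factor]   # keeps 8 ids

local_sub_block_ids = [
    local_block_ids[i // block_factor] * block_factor + i % block_factor
    for i in range(len(remote_block_ids))]

# Remote blocks 0..3 fill the four sub-blocks of local block 7,
# remote blocks 4..7 fill the four sub-blocks of local block 2.
assert local_sub_block_ids == [28, 29, 30, 31, 8, 9, 10, 11]
```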
@@ -1350,7 +1379,6 @@
     if socket_type not in (zmq.ROUTER, zmq.REQ):
         raise ValueError(f"Unexpected socket type: {socket_type}")

     ctx: Optional[zmq.Context] = None
     try:
         ctx = zmq.Context()  # type: ignore[attr-defined]
Review thread (on the hardcoded block_factor):
- Is this going to be a hardcoded value?
- It's hardcoded for now; maybe that's OK since this number won't change.
- It's better to check the block size on both sides.
- We can check whether the remote block_size is the expected one, but we don't know the remote block_size at this point because the handshake occurs afterwards.
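Following up on that last point: once the remote agent metadata has arrived (in add_remote_agent), the remote block size is known and the hardcoded factor can be validated. A minimal sketch, with the helper and example sizes being assumptions rather than code from the PR:

```python
# Hypothetical post-handshake validation of the hardcoded block_factor.
def validate_remote_block_size(local_block_size: int, remote_block_size: int,
                               block_factor: int) -> None:
    if local_block_size != remote_block_size * block_factor:
        raise ValueError(
            f"Expected {local_block_size=} to be {block_factor}x "
            f"{remote_block_size=} for heterogeneous PD")

# Example sizes chosen to satisfy the factor of 8; not taken from the PR.
validate_remote_block_size(local_block_size=128, remote_block_size=16,
                           block_factor=8)
```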