diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index e843a3e7e..5738c922b 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -184,6 +184,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): else torch.npu.synchronize ) + # invlalid block ids due to load errors + self._invalid_block_ids: set[int] = set() + def generate_hash(self, block_size: int, request: "Request") -> list[str]: token_ids = request.all_token_ids @@ -513,6 +516,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: # TODO error handling if self.global_rank == 0 or not self.load_only_first_rank: if self.store.wait(task) != 0: + self._invalid_block_ids.update( + metadata.request_meta[request_id].load_block_ids[1] + ) logger.error(f"request {request_id} load kv cache failed.") if self.load_only_first_rank: self._broadcast(req_broadcast_addr[request_id]) @@ -626,6 +632,18 @@ def wait_for_save(self) -> None: def clear_connector_metadata(self) -> None: super().clear_connector_metadata() + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Get the set of block IDs that failed to load. + + Returns: + Set of block IDs that encountered load errors. + Empty set if no load errors occurred. + """ + res = self._invalid_block_ids + self._invalid_block_ids = set() + return res + class UCMLayerWiseConnector(UCMDirectConnector): """ @@ -866,3 +884,13 @@ def clear_connector_metadata(self) -> None: after the model execution. """ self.connector.clear_connector_metadata() + + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Get the set of block IDs that failed to load. + + Returns: + Set of block IDs that encountered load errors. + Empty set if no load errors occurred. + """ + return self.connector.get_block_ids_with_load_errors()