Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
else torch.npu.synchronize
)

# invalid block ids due to load errors
self._invalid_block_ids: set[int] = set()

def generate_hash(self, block_size: int, request: "Request") -> list[str]:
token_ids = request.all_token_ids

Expand Down Expand Up @@ -513,6 +516,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
# TODO error handling
if self.global_rank == 0 or not self.load_only_first_rank:
if self.store.wait(task) != 0:
self._invalid_block_ids.update(
metadata.request_meta[request_id].load_block_ids[1]
)
logger.error(f"request {request_id} load kv cache failed.")
if self.load_only_first_rank:
self._broadcast(req_broadcast_addr[request_id])
Expand Down Expand Up @@ -626,6 +632,18 @@ def wait_for_save(self) -> None:
# Clear the connector metadata by deferring entirely to the parent class's
# clear_connector_metadata(); this override adds no extra cleanup of its own.
def clear_connector_metadata(self) -> None:
super().clear_connector_metadata()

def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Drain and return the block IDs recorded as failed loads.

    The internal record (``self._invalid_block_ids``) is swapped for a
    fresh empty set on every call, so IDs returned here are not returned
    again by a later call.

    Returns:
        The set of block IDs accumulated since the previous call.
        Empty set if nothing has been recorded.
    """
    # Hand the accumulated set to the caller and start a new empty one.
    failed_block_ids, self._invalid_block_ids = self._invalid_block_ids, set()
    return failed_block_ids


class UCMLayerWiseConnector(UCMDirectConnector):
"""
Expand Down Expand Up @@ -866,3 +884,13 @@ def clear_connector_metadata(self) -> None:
after the model execution.
"""
self.connector.clear_connector_metadata()

def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Report the block IDs that failed to load, as tracked by the wrapped
    connector.

    This is a pure pass-through to
    ``self.connector.get_block_ids_with_load_errors()``.

    Returns:
        The set of block IDs returned by the wrapped connector.
    """
    wrapped = self.connector
    return wrapped.get_block_ids_with_load_errors()