diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 5738c922b..66216a255 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -16,7 +16,6 @@ from vllm.distributed.parallel_state import get_tp_group, get_world_group from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import Request from ucm.logger import init_logger from ucm.shared.metrics import ucmmonitor @@ -29,6 +28,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request logger = init_logger(__name__) @@ -178,11 +178,12 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.metrics_config, ) self.monitor = ucmmonitor.StatsMonitor.get_instance() - self.synchronize = ( - torch.cuda.synchronize - if current_platform.is_cuda_alike() - else torch.npu.synchronize - ) + + self.synchronize = ( + torch.cuda.synchronize + if current_platform.is_cuda_alike() + else torch.npu.synchronize + ) # invlalid block ids due to load errors self._invalid_block_ids: set[int] = set() @@ -558,7 +559,9 @@ def wait_for_save(self) -> None: # TODO support PP if (self.is_mla or self.is_dsa) and self.global_rank != 0: return - if self.metrics_config: + if self.metrics_config or current_platform.device_type == "npu": + # When using vllm_ascend, we must add a synchronize here, otherwise an accuracy problem will arise. + # This has already been fixed on the latest main branch of vllm_ascend, so this synchronize will no longer be needed in future versions. self.synchronize() metadata = self._get_connector_metadata()