
Commit 12949c0
add hpu functions
1 parent e1404f3

2 files changed: +79 -4 lines

vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -121,6 +121,7 @@
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
+    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
     VLLM_ALL2ALL_BACKEND: str = "naive"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
@@ -832,6 +833,13 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_NIXL_SIDE_CHANNEL_HOST":
     lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),

+    # Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the
+    # consumer. This is only applicable when using NixlConnector in a
+    # disaggregated decode-prefill setup.
+    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
+
     # Port used for NIXL handshake between remote agents.
     "VLLM_NIXL_SIDE_CHANNEL_PORT":
     lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
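Note: the new variable follows the lazy-lookup pattern used throughout envs.py, where each entry maps a variable name to a zero-argument lambda that parses the environment only when the value is read. A minimal self-contained sketch of that pattern (the dictionary below is an illustrative stand-in, not the module's real table):

import os

# Stand-in for the lookup table in vllm/envs.py: each entry maps an
# environment variable name to a thunk that applies the default and the
# type coercion at read time, not at import time.
environment_variables = {
    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
}

os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"] = "300"
print(environment_variables["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"]())  # -> 300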

vllm/v1/worker/hpu_model_runner.py

Lines changed: 71 additions & 4 deletions
@@ -8,7 +8,7 @@
 import os
 import time
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union, Literal

 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
@@ -865,7 +865,7 @@ def _get_prompts_and_decodes(
         assert num_reqs > 0

         if scheduler_output.kv_connector_metadata:
-            requests = scheduler_output.kv_connector_metadata.requests
+            requests = scheduler_output.kv_connector_metadata.reqs_to_save
         else:
             requests = None

@@ -878,9 +878,9 @@ def _get_prompts_and_decodes(

             if requests is not None and req_id not in self.input_batch.req_type:
                 for request in requests:
-                    if request.req_id == req_id:
+                    if request == req_id:
                         self.input_batch.req_type[req_id] = "prefill" \
-                            if request.load_spec is None else "decode"
+                            if request is not None else "decode"
                         break

             num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
@@ -2433,6 +2433,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             #import remote_pdb; remote_pdb.set_trace()
             kv_caches = { layer: torch.stack((tup[0], tup[1])) for layer,tup in kv_caches.items()}
             get_kv_transfer_group().register_kv_caches(kv_caches)
+            get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)

             htorch.hpu.synchronize()
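The added line registers copy_kv_blocks (defined in the next hunk) as the host transfer callback on the KV transfer group. A hypothetical stub showing the callback contract implied by that call; ConnectorStub and its internals are illustrative assumptions, not vLLM's actual connector implementation:

from typing import Callable, Literal, Optional

import torch

# Callback signature matching copy_kv_blocks below: per-layer src/dst
# caches, matching block-id lists, and a copy direction.
CopyFn = Callable[[dict[str, torch.Tensor], dict[str, torch.Tensor],
                   list[int], list[int], Literal["h2d", "d2h"]], None]


class ConnectorStub:
    """Illustrative stand-in for the KV transfer group; only the single
    hook used in this hunk is modeled."""

    def __init__(self) -> None:
        self._copy_fn: Optional[CopyFn] = None

    def set_host_xfer_buffer_ops(self, copy_fn: CopyFn) -> None:
        # Store the callback; a real connector would invoke it later to
        # stage KV blocks out through host memory ("d2h") and load them
        # back onto the device ("h2d").
        self._copy_fn = copy_fn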

@@ -2473,3 +2474,69 @@ def kv_connector_no_forward(
     output.finished_sending = finished_sending
     output.finished_recving = finished_recving
     return output
+
+
+def _make_src_and_dst_indices(
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    src_indices = torch.tensor(src_block_ids,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_block_ids,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
+
+
+def _insert_blocks_to_hpu(
+    cpu_cache: torch.Tensor,
+    hpu_cache: torch.Tensor,
+    cpu_block_indices: torch.Tensor,
+    hpu_block_indices: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(hpu_cache, True)
+    hpu_cache[hpu_block_indices] = cpu_cache[cpu_block_indices].to(
+        hpu_cache.device)
+
+
+def _swap_out_hpu_blocks(
+    hpu_cache: torch.Tensor,
+    cpu_cache: torch.Tensor,
+    hpu_block_indices: torch.Tensor,
+    cpu_block_indices: torch.Tensor,
+) -> None:
+    """Copy HPU blocks to CPU blocks."""
+    torch.ops.xla.dynamo_set_buffer_donor_(hpu_cache, True)
+    cpu_cache[cpu_block_indices] = hpu_cache[hpu_block_indices].cpu()
+
+
+def copy_kv_blocks(
+    src_kv_caches: dict[str, torch.Tensor],
+    dst_kv_caches: dict[str, torch.Tensor],
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    direction: Literal["h2d", "d2h"],
+) -> None:
+    """Copy KV blocks between different buffers."""
+    if not src_kv_caches or not dst_kv_caches or \
+            not src_block_ids or not dst_block_ids or \
+            len(src_block_ids) != len(dst_block_ids):
+        return
+
+    src_device = next(iter(src_kv_caches.values())).device
+    dst_device = next(iter(dst_kv_caches.values())).device
+
+    src_indices, dst_indices = _make_src_and_dst_indices(
+        src_block_ids=src_block_ids,
+        dst_block_ids=dst_block_ids,
+        src_device=src_device,
+        dst_device=dst_device)
+
+    _copy_fn = _insert_blocks_to_hpu if direction == "h2d" else \
+        _swap_out_hpu_blocks
+    for layer_name in src_kv_caches:
+        src_tensor = src_kv_caches[layer_name]
+        dst_tensor = dst_kv_caches[layer_name]
+        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
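For intuition, a self-contained sketch of the per-layer block copy that copy_kv_blocks performs. Plain CPU tensors stand in for both sides so the sketch runs without Gaudi hardware, and the XLA buffer-donor hint is omitted; the shapes and block ids are made up for illustration:

import torch

# Toy single-layer KV cache: 8 blocks of (block_size, num_heads, head_dim).
num_blocks, block_size, num_heads, head_dim = 8, 16, 2, 4
cpu_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)
dev_cache = torch.zeros_like(cpu_cache)  # stand-in for the HPU-side cache

# Block ids to move, built as in _make_src_and_dst_indices.
src_indices = torch.tensor([0, 3, 5], dtype=torch.int64)
dst_indices = torch.tensor([2, 4, 6], dtype=torch.int64)

# "h2d": host blocks land in the chosen device slots
# (mirrors _insert_blocks_to_hpu, minus the buffer-donor call).
dev_cache[dst_indices] = cpu_cache[src_indices].to(dev_cache.device)
assert torch.equal(dev_cache[4], cpu_cache[3])

# "d2h": device blocks are swapped back out to host slots
# (mirrors _swap_out_hpu_blocks).
host_out = torch.zeros_like(cpu_cache)
host_out[src_indices] = dev_cache[dst_indices].cpu()
assert torch.equal(host_out[5], dev_cache[6])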
