Skip to content

Commit 094c136

Browse files
committed
adding indexing function
1 parent 12949c0 commit 094c136

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

vllm/v1/worker/hpu_model_runner.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2480,13 +2480,16 @@ def _make_src_and_dst_indices(
24802480
dst_block_ids: list[int],
24812481
src_device: Union[torch.device, str],
24822482
dst_device: Union[torch.device, str],
2483+
block_size: int,
24832484
) -> tuple[torch.Tensor, torch.Tensor]:
2484-
src_indices = torch.tensor(src_block_ids,
2485-
device=src_device,
2486-
dtype=torch.int64)
2487-
dst_indices = torch.tensor(dst_block_ids,
2488-
device=dst_device,
2489-
dtype=torch.int64)
2485+
2486+
for idx in range(len(src_block_ids)):
2487+
src_block_id = src_block_ids[idx]
2488+
src_indices = torch.range(block_size * src_block_id, block_size * (1 + src_block_id))
2489+
dst_block_id = dst_block_ids[idx]
2490+
dst_indices = torch.range(block_size * dst_block_id, block_size * (1 + dst_block_id))
2491+
2492+
24902493
return src_indices, dst_indices
24912494

24922495

@@ -2517,26 +2520,29 @@ def copy_kv_blocks(
25172520
src_block_ids: list[int],
25182521
dst_block_ids: list[int],
25192522
direction: Literal["h2d", "d2h"],
2523+
block_size: int
25202524
) -> None:
25212525
"""Copy kv blocks between different buffers."""
25222526
if not src_kv_caches or not dst_kv_caches or \
25232527
not src_block_ids or not dst_block_ids or \
25242528
len(src_block_ids) != len(dst_block_ids):
25252529
return
2526-
2530+
assert len(src_block_ids) == len(dst_block_ids)
25272531
src_device = next(iter(src_kv_caches.values())).device
25282532
dst_device = next(iter(dst_kv_caches.values())).device
25292533

25302534
src_indices, dst_indices = _make_src_and_dst_indices(
25312535
src_block_ids=src_block_ids,
25322536
dst_block_ids=dst_block_ids,
25332537
src_device=src_device,
2534-
dst_device=dst_device)
2535-
2536-
_copy_fn = _insert_blocks_to_hpu if direction == "h2d" else \
2537-
_swap_out_hpu_blocks
2538-
for layer_name in src_kv_caches:
2539-
src_tensor = src_kv_caches[layer_name]
2540-
dst_tensor = dst_kv_caches[layer_name]
2541-
_copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
2538+
dst_device=dst_device,
2539+
block_size)
2540+
2541+
for idx, (layer, kv_layer) in enumerate(src_kv_caches):
2542+
if direction == "h2d":
2543+
k, v = kv_layer[0], kv_layer[1]
2544+
else:
2545+
k, v = kv_layer
2546+
dst_kv_caches[layer][0][dst_indices].copy_(k[src_indices], non_blocking = False)
2547+
dst_kv_caches[layer][1][dst_indices].copy_(v[src_indices], non_blocking = False)
25422548

0 commit comments

Comments
 (0)