@@ -2480,13 +2480,16 @@ def _make_src_and_dst_indices(
24802480 dst_block_ids : list [int ],
24812481 src_device : Union [torch .device , str ],
24822482 dst_device : Union [torch .device , str ],
2483+ block_size : int ,
24832484) -> tuple [torch .Tensor , torch .Tensor ]:
2484- src_indices = torch .tensor (src_block_ids ,
2485- device = src_device ,
2486- dtype = torch .int64 )
2487- dst_indices = torch .tensor (dst_block_ids ,
2488- device = dst_device ,
2489- dtype = torch .int64 )
2485+
2486+ for idx in range (len (src_block_ids )):
2487+ src_block_id = src_block_ids [idx ]
2488+ src_indices = torch .range (block_size * src_block_id , block_size * (1 + src_block_id ))
2489+ dst_block_id = dst_block_ids [idx ]
2490+ dst_indices = torch .range (block_size * dst_block_id , block_size * (1 + dst_block_id ))
2491+
2492+
24902493 return src_indices , dst_indices
24912494
24922495
@@ -2517,26 +2520,29 @@ def copy_kv_blocks(
25172520 src_block_ids : list [int ],
25182521 dst_block_ids : list [int ],
25192522 direction : Literal ["h2d" , "d2h" ],
2523+ block_size : int
25202524) -> None :
25212525 """Copy kv blocks between different buffers."""
25222526 if not src_kv_caches or not dst_kv_caches or \
25232527 not src_block_ids or not dst_block_ids or \
25242528 len (src_block_ids ) != len (dst_block_ids ):
25252529 return
2526-
2530+ assert len ( src_block_ids ) == len ( dst_block_ids )
25272531 src_device = next (iter (src_kv_caches .values ())).device
25282532 dst_device = next (iter (dst_kv_caches .values ())).device
25292533
25302534 src_indices , dst_indices = _make_src_and_dst_indices (
25312535 src_block_ids = src_block_ids ,
25322536 dst_block_ids = dst_block_ids ,
25332537 src_device = src_device ,
2534- dst_device = dst_device )
2535-
2536- _copy_fn = _insert_blocks_to_hpu if direction == "h2d" else \
2537- _swap_out_hpu_blocks
2538- for layer_name in src_kv_caches :
2539- src_tensor = src_kv_caches [layer_name ]
2540- dst_tensor = dst_kv_caches [layer_name ]
2541- _copy_fn (src_tensor , dst_tensor , src_indices , dst_indices )
2538+ dst_device = dst_device ,
2539+ block_size )
2540+
2541+ for idx , (layer , kv_layer ) in enumerate (src_kv_caches ):
2542+ if direction == "h2d" :
2543+ k , v = kv_layer [0 ], kv_layer [1 ]
2544+ else :
2545+ k , v = kv_layer
2546+ dst_kv_caches [layer ][0 ][dst_indices ].copy_ (k [src_indices ], non_blocking = False )
2547+ dst_kv_caches [layer ][1 ][dst_indices ].copy_ (v [src_indices ], non_blocking = False )
25422548
0 commit comments