88import os
99import time
1010from dataclasses import dataclass , field , fields
11- from typing import TYPE_CHECKING , Any , Callable , Optional , TypeAlias , Union
11+ from typing import TYPE_CHECKING , Any , Callable , Optional , TypeAlias , Union , Literal
1212
1313import habana_frameworks .torch as htorch
1414import habana_frameworks .torch .internal .bridge_config as bc
@@ -865,7 +865,7 @@ def _get_prompts_and_decodes(
865865 assert num_reqs > 0
866866
867867 if scheduler_output .kv_connector_metadata :
868- requests = scheduler_output .kv_connector_metadata .requests
868+ requests = scheduler_output .kv_connector_metadata .reqs_to_save
869869 else :
870870 requests = None
871871
@@ -878,9 +878,9 @@ def _get_prompts_and_decodes(
878878
879879 if requests is not None and req_id not in self .input_batch .req_type :
880880 for request in requests :
881- if request . req_id == req_id :
881+ if request == req_id :
882882 self .input_batch .req_type [req_id ] = "prefill" \
883- if request . load_spec is None else "decode"
883+ if request is not None else "decode"
884884 break
885885
886886 num_computed_tokens = self .input_batch .num_computed_tokens_cpu [i ]
@@ -2433,6 +2433,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
24332433 #import remote_pdb; remote_pdb.set_trace()
24342434 kv_caches = { layer : torch .stack ((tup [0 ], tup [1 ])) for layer ,tup in kv_caches .items ()}
24352435 get_kv_transfer_group ().register_kv_caches (kv_caches )
2436+ get_kv_transfer_group ().set_host_xfer_buffer_ops (copy_kv_blocks )
24362437
24372438 htorch .hpu .synchronize ()
24382439
@@ -2473,3 +2474,69 @@ def kv_connector_no_forward(
24732474 output .finished_sending = finished_sending
24742475 output .finished_recving = finished_recving
24752476 return output
2477+
def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn two lists of block ids into int64 index tensors.

    Args:
        src_block_ids: Block ids to read from.
        dst_block_ids: Block ids to write to.
        src_device: Device on which the source index tensor is created.
        dst_device: Device on which the destination index tensor is created.

    Returns:
        ``(src_indices, dst_indices)``: one ``torch.int64`` tensor per input
        list, each placed on its respective device.
    """

    def _as_index(block_ids: list[int],
                  device: Union[torch.device, str]) -> torch.Tensor:
        # int64 is the dtype used for advanced indexing by the callers.
        return torch.tensor(block_ids, dtype=torch.int64, device=device)

    return (_as_index(src_block_ids, src_device),
            _as_index(dst_block_ids, dst_device))
2491+
2492+
def _insert_blocks_to_hpu(
    cpu_cache: torch.Tensor,
    hpu_cache: torch.Tensor,
    cpu_block_indices: torch.Tensor,
    hpu_block_indices: torch.Tensor,
) -> None:
    """Copy selected blocks from a host cache into a device cache, in place.

    Rows of ``cpu_cache`` selected by ``cpu_block_indices`` are moved to
    ``hpu_cache.device`` and written into ``hpu_cache`` at
    ``hpu_block_indices``. Both index tensors address dim 0 of their cache.

    Args:
        cpu_cache: Source tensor (host side).
        hpu_cache: Destination tensor; modified in place.
        cpu_block_indices: Indices into dim 0 of ``cpu_cache``.
        hpu_block_indices: Indices into dim 0 of ``hpu_cache``.
    """
    # The XLA-only ``torch.ops.xla.dynamo_set_buffer_donor_`` hint that used
    # to be called here is not registered outside torch_xla, so it raised
    # before any copy happened. It is a TPU buffer-donation hint and is not
    # needed for the copy itself.
    # NOTE(review): indexing is along dim 0. The caches registered in
    # initialize_kv_cache are built with torch.stack((k, v)), which would
    # make dim 0 the K/V axis -- confirm the layout the connector passes here.
    hpu_cache[hpu_block_indices] = cpu_cache[cpu_block_indices].to(
        hpu_cache.device)
2502+
2503+
def _swap_out_hpu_blocks(
    hpu_cache: torch.Tensor,
    cpu_cache: torch.Tensor,
    hpu_block_indices: torch.Tensor,
    cpu_block_indices: torch.Tensor,
) -> None:
    """Copy selected blocks from a device cache into a host cache, in place.

    Rows of ``hpu_cache`` selected by ``hpu_block_indices`` are brought to
    the CPU with ``.cpu()`` and written into ``cpu_cache`` at
    ``cpu_block_indices``. Both index tensors address dim 0 of their cache.

    Args:
        hpu_cache: Source tensor (device side); not modified.
        cpu_cache: Destination tensor; modified in place.
        hpu_block_indices: Indices into dim 0 of ``hpu_cache``.
        cpu_block_indices: Indices into dim 0 of ``cpu_cache``.
    """
    # The XLA-only ``torch.ops.xla.dynamo_set_buffer_donor_`` hint that used
    # to be called here is not registered outside torch_xla, so it raised
    # before any copy happened. It is a TPU buffer-donation hint and is not
    # needed for the copy itself.
    # NOTE(review): indexing is along dim 0. The caches registered in
    # initialize_kv_cache are built with torch.stack((k, v)), which would
    # make dim 0 the K/V axis -- confirm the layout the connector passes here.
    cpu_cache[cpu_block_indices] = hpu_cache[hpu_block_indices].cpu()
2513+
def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy KV-cache blocks layer by layer between two sets of buffers.

    The call is a silent no-op when either cache dict or either id list is
    empty, or when the two id lists differ in length.

    Args:
        src_kv_caches: Source tensors keyed by layer name.
        dst_kv_caches: Destination tensors keyed by layer name; every key of
            ``src_kv_caches`` is looked up here.
        src_block_ids: Block ids to read from each source tensor.
        dst_block_ids: Block ids to write in each destination tensor,
            paired positionally with ``src_block_ids``.
        direction: ``"h2d"`` selects the host-to-device copy routine; any
            other value selects the device-to-host routine.
    """
    # Nothing to transfer if any of the inputs is empty.
    if not (src_kv_caches and dst_kv_caches and src_block_ids
            and dst_block_ids):
        return
    # Ids are paired one-to-one; a mismatch is skipped rather than raised.
    if len(src_block_ids) != len(dst_block_ids):
        return

    # Each index tensor lives on the device of the cache it will index; the
    # first tensor of each dict is taken as representative of that side.
    first_src = next(iter(src_kv_caches.values()))
    first_dst = next(iter(dst_kv_caches.values()))
    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids, dst_block_ids, first_src.device, first_dst.device)

    if direction == "h2d":
        transfer = _insert_blocks_to_hpu
    else:
        transfer = _swap_out_hpu_blocks

    for layer_name, src_tensor in src_kv_caches.items():
        transfer(src_tensor, dst_kv_caches[layer_name], src_indices,
                 dst_indices)
2542+
0 commit comments