@@ -17,11 +17,13 @@
 import torch.distributed
 import vllm_hpu_extension.environment as environment
 from vllm_hpu_extension.bucketing.common import HPUBucketingManager
+from vllm_hpu_extension.defragmentation import OnlineDefragmenter
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler,
                                          HabanaProfilerCounterHelper,
-                                         format_bytes)
-from vllm_hpu_extension.runtime import get_config
+                                         format_bytes, setup_profiler)
+from vllm_hpu_extension.runtime import finalize_config, get_config
+from vllm_hpu_extension.utils import pad_list
 
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
@@ -59,25 +61,6 @@
 _TYPE_CACHE: dict[str, dict[str, Any]] = {}
 
 
-def setup_profiler(warmup, active):
-    schedule = torch.profiler.schedule(wait=0,
-                                       warmup=warmup,
-                                       active=active,
-                                       repeat=1)
-    activities = [
-        torch.profiler.ProfilerActivity.CPU,
-        torch.profiler.ProfilerActivity.HPU
-    ]
-    profiler = torch.profiler.profile(
-        schedule=schedule,
-        activities=activities,
-        on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
-                                                                use_gzip=True),
-        record_shapes=False,
-        with_stack=True)
-    return profiler
-
-
 @dataclass
 class PromptDecodeInfo:
     prompt_req_ids: list[str]
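
The hand-rolled `setup_profiler` above is deleted in favor of the identically named helper now imported from `vllm_hpu_extension.profiler`, and `pad_list` (removed in the next hunk) likewise moves to `vllm_hpu_extension.utils`. Assuming the extension keeps the same signature, the deleted helper was driven with the standard `torch.profiler` step pattern, roughly:

```python
profiler = setup_profiler(warmup=2, active=3)
profiler.start()
for _ in range(2 + 3):       # wait=0, so warmup + active steps suffice
    run_workload_step()      # hypothetical: one forward pass
    profiler.step()
profiler.stop()              # gzipped TensorBoard trace is written to ./
```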
@@ -541,13 +524,6 @@ def round_up(value: int, k: int):
     return (value + k - 1) // k * k
 
 
-def pad_list(input, target_len, val_generator):
-    padding = target_len - len(input)
-    if padding > 0:
-        input.extend(itertools.islice(val_generator, padding))
-    return input
-
-
 class HPUModelRunner:
 
     def __init__(
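
For reference, the removed `pad_list` padded a list in place to a target length by drawing fill values from a generator; the `vllm_hpu_extension.utils` replacement is assumed here to be drop-in compatible. A self-contained sketch of that behavior:

```python
import itertools

def pad_list(input, target_len, val_generator):
    # Extend `input` in place until it reaches `target_len`;
    # lists already at or past the target pass through unchanged.
    padding = target_len - len(input)
    if padding > 0:
        input.extend(itertools.islice(val_generator, padding))
    return input

# e.g. padding a block list up to a bucket boundary with dummy ids
assert pad_list([7, 8], 5, itertools.cycle([0])) == [7, 8, 0, 0, 0]
```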
@@ -558,6 +534,8 @@ def __init__(
     ):
         # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_vllm_config(vllm_config)
+        finalize_config()
+
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -671,6 +649,8 @@ def __init__(
         self.profiler = HabanaHighLevelProfiler()
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
 
+        self.defragmenter = OnlineDefragmenter()
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
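
The runner now owns an `OnlineDefragmenter` instance. Its implementation lives in `vllm_hpu_extension`, but the call sites in this diff (`resolve`, `resolve_all`, `update_state`, `defragment`, `initialize`, `enabled`) imply an indirection table from scheduler-visible block ids to current physical KV-cache blocks. A hypothetical stand-in, only to illustrate the contract the runner relies on:

```python
class DefragmenterContractSketch:
    """Illustrative stub, not the vllm_hpu_extension implementation."""

    def __init__(self) -> None:
        self.enabled = True
        # logical block id -> physical block id; identity when absent
        self.mapping: dict[int, int] = {}

    def resolve(self, block: int) -> int:
        return self.mapping.get(block, block)

    def resolve_all(self, block_tables: list[list[int]]) -> list[list[int]]:
        return [[self.resolve(b) for b in table] for table in block_tables]
```

Because every block id coming from the scheduler passes through `resolve` before touching the KV cache, the defragmenter can migrate blocks without the scheduler ever seeing the new layout.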
@@ -1075,6 +1055,7 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
             num_blocks = round_up(context_len + query_len,
                                   self.block_size) // self.block_size
             blocks = block_table_cpu_tensor[batch_idx, :num_blocks].tolist()
+            blocks = [self.defragmenter.resolve(b) for b in blocks]
 
             prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx]
             #TODO: Fix non-prompt case
@@ -1311,6 +1292,8 @@ def _prepare_decode_inputs(self, num_decodes,
                                             dim=1,
                                             index=(index //
                                                    self.block_size))
+        block_number.apply_(self.defragmenter.resolve)
+
         block_offsets = padded_index % self.block_size
         slot_mapping = block_number * self.block_size + block_offsets
         # set an out of range value for the padding tokens so that they
@@ -1320,6 +1303,8 @@ def _prepare_decode_inputs(self, num_decodes,
             range(self._PAD_SLOT_ID, self._PAD_SLOT_ID + self.block_size))
         slot_mapping[num_decodes:].apply_(lambda _, ds=dummy_slots: next(ds))
 
+        block_tables_list = self.defragmenter.resolve_all(block_tables_list)
+
         # CONTEXT_LENS [batch_size]
         block_list, block_groups, block_usage = \
             self.get_habana_paged_attn_buffers(
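
In the decode path the indirection is applied twice: element-wise on the gathered `block_number` tensor (`Tensor.apply_` works here because these are CPU tensors) before slot offsets are computed, and over the whole `block_tables_list` before the paged-attention buffers are built. The slot arithmetic itself is unchanged; with hypothetical numbers:

```python
block_size = 4
position = 9                           # token position within the sequence
logical_block = 5                      # what the scheduler handed us
resolved_block = 11                    # defragmenter.resolve(5) after a migration
block_offset = position % block_size   # 1
slot = resolved_block * block_size + block_offset
assert slot == 45                      # value written into slot_mapping
```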
@@ -1598,6 +1583,20 @@ def execute_model(
         # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] # noqa
         # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2]
 
+        if self.defragmenter.enabled and self.kv_caches:
+            new = {
+                req.req_id: flatten(req.block_ids)
+                for req in scheduler_output.scheduled_new_reqs if req.block_ids
+            }
+            cached = {
+                req.req_id: flatten(req.new_block_ids)
+                for req in scheduler_output.scheduled_cached_reqs
+                if req.new_block_ids
+            }
+            self.defragmenter.update_state(new | cached,
+                                           scheduler_output.finished_req_ids)
+            self.defragmenter.defragment()
+
         batch_changed = self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             if not has_kv_transfer_group():
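
Each step, before `_update_states` runs, the defragmenter is fed the scheduler's latest view of block ownership: all block ids of newly scheduled requests, any blocks newly appended to cached requests, and the ids of finished requests, whose blocks become holes that `defragment()` can compact. `new | cached` is the dict-merge operator (Python 3.9+), and `flatten` collapses the nested per-group block id lists into one flat list. A toy view of the payload:

```python
new = {"req-a": [0, 1, 2]}        # freshly scheduled request
cached = {"req-b": [7]}           # block appended to a running request
finished_req_ids = {"req-c"}      # its blocks may now be reclaimed
payload = new | cached            # {'req-a': [0, 1, 2], 'req-b': [7]}
# update_state(payload, finished_req_ids) records ownership; defragment()
# can then migrate live blocks and refresh the logical->physical mapping.
```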
@@ -2202,6 +2201,7 @@ def _read_profiling_cfg(self):
 
     @torch.inference_mode()
     def warmup_model(self) -> None:
+        self.defragmenter.initialize(self.kv_caches, self.block_size)
         if not self.enable_bucketing:
             return
         prompt_profile_cfg, decode_profile_cfg = self._read_profiling_cfg()