 import torch.distributed
 import vllm_hpu_extension.environment as environment
 from vllm_hpu_extension.bucketing.common import HPUBucketingManager
-from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
+from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
+                                         HabanaMemoryProfiler,
+                                         HabanaProfilerCounterHelper,
+                                         format_bytes)
 from vllm_hpu_extension.runtime import get_config

 from vllm.attention.backends.abstract import AttentionType
@@ -525,6 +528,7 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device = 'hpu',
+        is_driver_worker: bool = False,
     ):
         # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_vllm_config(vllm_config)
@@ -538,6 +542,7 @@ def __init__(
         self.speculative_config = vllm_config.speculative_config
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
+        self.is_driver_worker = is_driver_worker

         self.sampler = get_sampler()

@@ -636,6 +641,9 @@ def __init__(
         # TODO(madamczyk-intel): add a knob for that
         # TODO(madamczyk-intel): debug why increasing it lowers acc
         self.logits_rounding = 1
+        # High-level profiler
+        self.profiler = HabanaHighLevelProfiler()
+        self.profiler_counter_helper = HabanaProfilerCounterHelper()

     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
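
Note: the profiler API used throughout this diff is assumed to expose paired start()/end() calls, a record_event() context manager, microsecond timestamps, and a record_counter() sink. The real HabanaHighLevelProfiler and HabanaProfilerCounterHelper live in vllm_hpu_extension.profiler; the stand-in below is only a sketch of that assumed interface, inferred from how it is called in this diff, not the actual implementation.

    import time
    from contextlib import contextmanager

    class StubHighLevelProfiler:
        """Illustrative stand-in for the profiler interface used in this diff."""

        def __init__(self, enabled: bool = True):
            self.enabled = enabled
            self.events = []    # (name, start_us, end_us)
            self.counters = []  # (start_timestamp_us, counter_dict)
            self._stack = []

        def get_timestamp_us(self) -> int:
            # Wall-clock timestamp in microseconds.
            return int(time.time() * 1e6)

        def start(self, kind: str, name: str) -> None:
            if self.enabled:
                self._stack.append((f"{kind}:{name}", self.get_timestamp_us()))

        def end(self) -> None:
            if self.enabled and self._stack:
                name, start_us = self._stack.pop()
                self.events.append((name, start_us, self.get_timestamp_us()))

        @contextmanager
        def record_event(self, kind: str, name: str):
            # Context-manager form of start()/end().
            self.start(kind, name)
            try:
                yield
            finally:
                self.end()

        def record_counter(self, start_timestamp_us: int, counters: dict) -> None:
            if self.enabled:
                self.counters.append((start_timestamp_us, counters))
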
@@ -841,6 +849,7 @@ def _get_prompts_and_decodes(

         # Traverse decodes first
         decode_req_ids = []
+        num_computed_tokens_decode = []
         for i in range(num_reqs):
             req_id = self.input_batch.req_ids[i]
             assert req_id is not None
@@ -857,6 +866,11 @@ def _get_prompts_and_decodes(
             # This is decode
             assert num_scheduled_tokens == 1
             decode_req_ids.append(req_id)
+            num_computed_tokens_decode.append(int(num_computed_tokens + 1))
+
+        if self.profiler.enabled:
+            self.profiler_counter_helper.capture_decode_seq_stats(
+                num_computed_tokens_decode)

         # Traverse prompts
         prompt_req_ids = []
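
A toy illustration (made-up values) of what the new decode bookkeeping collects: one entry per decode request, equal to that request's computed-token count plus the single token scheduled this step, which is what capture_decode_seq_stats receives when profiling is enabled.

    # Hypothetical per-request computed-token counts for three decode requests.
    num_computed_tokens_cpu = [17, 42, 99]
    num_computed_tokens_decode = [int(n + 1) for n in num_computed_tokens_cpu]
    assert num_computed_tokens_decode == [18, 43, 100]
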
@@ -1071,6 +1085,8 @@ def _form_prefill_batch(self, contents):
         token_ids = contents.token_ids
         req_ids = contents.req_ids
         query_lens = [len(tids) for tids in contents.token_ids]
+        if self.profiler.enabled:
+            self.profiler_counter_helper.capture_prompt_seq_stats(query_lens)
         context_lens = contents.context_lens

         token_positions = [
@@ -1375,17 +1391,31 @@ def _execute_model_generic(self,
             # no hpu graphs for t.compile?
             use_graphs = False
         trimmed_attn_metadata = trim_attn_metadata(attn_metadata)
-        hidden_states = self.model.forward(input_ids=token_ids,
-                                           positions=position_ids,
-                                           attn_metadata=trimmed_attn_metadata,
-                                           kv_caches=kv_caches)
+        if self.is_driver_worker:
+            model_event_name = ("model_forward_"
+                                f"bs{batch_size}_"
+                                f"seq{seq_len}_"
+                                f"ctx{num_blocks}_"
+                                f"graphs{'T' if use_graphs else 'F'}")
+        else:
+            model_event_name = 'model_executable'
+        with self.profiler.record_event('internal', model_event_name):
+            hidden_states = self.model.forward(
+                input_ids=token_ids,
+                positions=position_ids,
+                attn_metadata=trimmed_attn_metadata,
+                kv_caches=kv_caches)
         # NOTE(kzawora): returning hidden_states is required in prompt logprobs
         # scenarios, as they will do logit processing on their own
         non_flattened_hidden_states = hidden_states

         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
+        with self.profiler.record_event('internal', ('compute_logits'
+                                                     f'{batch_size}_'
+                                                     f'seq{seq_len}_ctx'
+                                                     f'{num_blocks}')):
+            logits = self.model.compute_logits(hidden_states, None)
         return non_flattened_hidden_states, logits

     def _get_prompt_logprobs_dict(
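
On the driver worker, forward events are named after the padded batch shape, so bucketed shapes are easy to pick out in a trace. A quick sketch of the resulting names, with made-up values:

    # Hypothetical values; reproduces the f-string pattern used above.
    batch_size, seq_len, num_blocks, use_graphs = 4, 1024, 32, True
    model_event_name = ("model_forward_"
                        f"bs{batch_size}_"
                        f"seq{seq_len}_"
                        f"ctx{num_blocks}_"
                        f"graphs{'T' if use_graphs else 'F'}")
    assert model_event_name == "model_forward_bs4_seq1024_ctx32_graphsT"
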
@@ -1532,8 +1562,9 @@ def execute_model(
         num_decodes = len(pd_info.decode_req_ids)
         num_prefills = len(pd_info.prompt_req_ids)
         num_reqs = num_decodes + num_prefills
-        prefill_data, decode_data = self._prepare_inputs(
-            scheduler_output, num_prefills, num_decodes)
+        with self.profiler.record_event('internal', 'prepare_input_tensors'):
+            prefill_data, decode_data = self._prepare_inputs(
+                scheduler_output, num_prefills, num_decodes)

         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
@@ -1548,64 +1579,102 @@ def execute_model(
                   attn_metadata, logits_indices,
                   logits_requests) in enumerate(
                       zip(*shallow_tuple(prefill_data))):
+            self.event_start = self.profiler.get_timestamp_us()
+            self.profiler.start("internal", "prefill")
             htorch.core.mark_step()
             prefill_hidden_states_ts, logits_device = \
                 self._execute_model_generic(
                     token_ids, position_ids, attn_metadata, logits_indices,
                     self.kv_caches)
             htorch.core.mark_step()
-            sampling_metadata = self._prepare_sampling(
-                batch_changed, req_id, pad_to=logits_device.shape[0])
-            sampler_output = self.sampler(
-                logits=logits_device, sampling_metadata=sampling_metadata)
-            prefill_sampled_token_ids.append(
-                sampler_output.sampled_token_ids.flatten())
-            prefill_sampled_requests.extend(logits_requests)
+            with self.profiler.record_event('internal', "sampler"):
+                sampling_metadata = self._prepare_sampling(
+                    batch_changed, req_id, pad_to=logits_device.shape[0])
+                sampler_output = self.sampler(
+                    logits=logits_device,
+                    sampling_metadata=sampling_metadata)
+                prefill_sampled_token_ids.append(
+                    sampler_output.sampled_token_ids.flatten())
+                prefill_sampled_requests.extend(logits_requests)
             htorch.core.mark_step()
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model_generic' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=self._seq_len(attn_metadata),
+                    batch_size_padded=token_ids.size(0),
+                    real_batch_size=len(req_id),
+                    prompt_batch_idx=idx,
+                    is_prompt=True)
+                self.profiler.record_counter(self.event_start, counters)
+        if self.is_driver_worker and self.profiler.enabled:
+            self.profiler_counter_helper.reset_prompt_seq_stats()

         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
+            self.event_start = self.profiler.get_timestamp_us()
+            self.profiler.start("internal", "decode")
             assert decode_data is not None
             htorch.core.mark_step()
             _, logits_device = self._execute_model_generic(
                 decode_data.token_ids, decode_data.position_ids,
                 decode_data.attn_metadata, decode_data.logits_indices,
                 self.kv_caches)
             htorch.core.mark_step()
-            sampling_metadata = self._prepare_sampling(
-                batch_changed,
-                pd_info.decode_req_ids,
-                pad_to=logits_device.shape[0])
-            sampler_output = self.sampler(logits=logits_device,
-                                          sampling_metadata=sampling_metadata)
-            decode_sampled_token_ids.append(
-                sampler_output.sampled_token_ids.flatten())
-            decode_sampled_requests.extend(
-                self.input_batch.req_ids[:num_decodes])
+            with self.profiler.record_event('internal', "sampler"):
+                sampling_metadata = self._prepare_sampling(
+                    batch_changed,
+                    pd_info.decode_req_ids,
+                    pad_to=logits_device.shape[0])
+                sampler_output = self.sampler(
+                    logits=logits_device, sampling_metadata=sampling_metadata)
+                decode_sampled_token_ids.append(
+                    sampler_output.sampled_token_ids.flatten())
+                decode_sampled_requests.extend(
+                    self.input_batch.req_ids[:num_decodes])
             htorch.core.mark_step()
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=self._seq_len(decode_data.attn_metadata),
+                    batch_size_padded=\
+                        decode_data.token_ids.size(0),  # type: ignore
+                    real_batch_size=decode_data.num_decodes,
+                    prompt_batch_idx=None,
+                    is_prompt=False)
+                self.profiler.record_counter(self.event_start, counters)
         # From this point onward, all operations are done on CPU.
         # We already have tokens. Let's copy the data to
         # CPU as is, and then discard padded tokens.
-
-        prefill_sampled_token_ids = [
-            tensor.cpu() for tensor in prefill_sampled_token_ids
-        ]
-        decode_sampled_token_ids = [
-            tensor.cpu()[:num_decodes] for tensor in decode_sampled_token_ids
-        ]
-        sampled_token_ids_list = torch.cat(decode_sampled_token_ids +
-                                           prefill_sampled_token_ids).tolist()
-        sampled_token_requests = \
-            decode_sampled_requests + prefill_sampled_requests
-        max_req_index = max(self.input_batch.req_id_to_index.values())
-        postprocessed_sampled_token_ids: list[list]
-        postprocessed_sampled_token_ids = [[]
-                                           for _ in range(max_req_index + 1)]
-        for tok_id, req_id in zip(sampled_token_ids_list,
-                                  sampled_token_requests):
-            postprocessed_sampled_token_ids[
-                self.input_batch.req_id_to_index[req_id]].append(tok_id)
+        with self.profiler.record_event('internal', "sampler_postprocessing"):
+            prefill_sampled_token_ids = [
+                tensor.cpu() for tensor in prefill_sampled_token_ids
+            ]
+            decode_sampled_token_ids = [
+                tensor.cpu()[:num_decodes]
+                for tensor in decode_sampled_token_ids
+            ]
+            sampled_token_ids_list = torch.cat(
+                decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
+            sampled_token_requests = \
+                decode_sampled_requests + prefill_sampled_requests
+            max_req_index = max(self.input_batch.req_id_to_index.values())
+            postprocessed_sampled_token_ids: list[list]
+            postprocessed_sampled_token_ids = [[]
+                                               for _ in range(max_req_index +
+                                                              1)]
+            for tok_id, req_id in zip(sampled_token_ids_list,
+                                      sampled_token_requests):
+                postprocessed_sampled_token_ids[
+                    self.input_batch.req_id_to_index[req_id]].append(tok_id)


         # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs
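
The prefill and decode branches above follow the same accounting pattern: take a timestamp, open a named event, run the step, close the event on the driver, then attach a counter dict keyed to the start timestamp. A schematic sketch of that flow, using the stub profiler from the earlier sketch; the counter fields here are examples only, not the exact keys produced by HabanaProfilerCounterHelper.get_counter_dict.

    def profile_step(profiler, run_step, phase, real_batch_size,
                     batch_size_padded, is_driver_worker=True):
        # timestamp -> start -> work -> end -> duration -> counters
        event_start = profiler.get_timestamp_us()
        profiler.start("internal", phase)  # "prefill" or "decode"
        run_step()
        if is_driver_worker and profiler.enabled:
            profiler.end()
            event_end = profiler.get_timestamp_us()
            counters = {
                "duration_us": event_end - event_start,  # example fields only
                "real_batch_size": real_batch_size,
                "batch_size_padded": batch_size_padded,
            }
            profiler.record_counter(event_start, counters)
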
@@ -1796,6 +1865,14 @@ def warmup_scenario(self,
         slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)

         use_graphs = self._use_graphs()
+        phase = "prompt" if is_prompt else "decode"
+        scenario_name = ("warmup_"
+                         f"{phase}_"
+                         f"bs{batch_size}_"
+                         f"seq{query_seq_len}_"
+                         f"ctx{num_blocks}_"
+                         f"graphs{'T' if use_graphs else 'F'}")
+
         input_ids = torch.zeros((batch_size, query_seq_len),
                                 dtype=torch.int32,
                                 device='cpu')
@@ -1809,6 +1886,7 @@ def warmup_scenario(self,
         input_ids_device = _async_h2d_tensor_copy(input_ids, self.device)
         position_ids_device = _async_h2d_tensor_copy(position_ids, self.device)
         slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)
+        self.profiler.start('internal', scenario_name)
         times = 3 if use_graphs or is_pt_profiler_run else 1
         for time_index in range(times):
             if is_prompt:
@@ -1882,6 +1960,7 @@ def warmup_scenario(self,
             }  # NOTE(kzawora): idk what to set here
             max_num_logprobs = 0  # NOTE(kzawora): idk what to set here
             # NOTE(kzawora: do this in a smarter way)
+            self.profiler.end()
             return None
         htorch.core.mark_step()
         sampling_metadata = SamplingMetadata(
@@ -2138,6 +2217,7 @@ def warmup_model(self) -> None:
             logger.info("Skipping warmup...")
             return

+        self.profiler.start('internal', 'warmup')
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()

@@ -2189,6 +2269,7 @@ def warmup_model(self) -> None:
             f"Warmup finished in {elapsed_time:.0f} secs, "
             f"allocated {format_bytes(end_mem - start_mem)} of device memory")
         logger.info(msg)
+        self.profiler.end()

     def shutdown_inc(self):
         can_finalize_inc = self._is_quant_with_inc() and \
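
With the two hunks above, every per-scenario event opened in warmup_scenario starts and ends inside the outer 'warmup' event, so a trace shows the individual warmup shapes nested under one umbrella span. A small sketch of that nesting, reusing the stub profiler from the earlier sketch (the scenario names are made-up examples that follow the warmup_scenario naming pattern):

    profiler = StubHighLevelProfiler()
    profiler.start('internal', 'warmup')
    for scenario_name in ("warmup_prompt_bs1_seq128_ctx0_graphsF",
                          "warmup_decode_bs32_seq1_ctx128_graphsT"):
        profiler.start('internal', scenario_name)
        profiler.end()  # closes the scenario event
    profiler.end()      # closes the enclosing 'warmup' event
    assert len(profiler.events) == 3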