
Commit bf3e6b0

Port high-level profiler to V1 engine (#1501)
1 parent d9aa3c1 · commit bf3e6b0

5 files changed (+179, -196 lines)

requirements/hpu.txt

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ ray
 triton==3.1.0
 setuptools>=77.0.3
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@d928f25
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@cbe274b
 
 # Dependencies for HPU vllm docker image
 datasets
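
As a usage note (not part of the diff), the updated pin only takes effect once the HPU requirements are reinstalled, e.g.:

    pip install -r requirements/hpu.txt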

vllm/v1/worker/hpu_model_runner.py

Lines changed: 125 additions & 44 deletions

@@ -16,7 +16,10 @@
 import torch.distributed
 import vllm_hpu_extension.environment as environment
 from vllm_hpu_extension.bucketing.common import HPUBucketingManager
-from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
+from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
+                                         HabanaMemoryProfiler,
+                                         HabanaProfilerCounterHelper,
+                                         format_bytes)
 from vllm_hpu_extension.runtime import get_config
 
 from vllm.attention.backends.abstract import AttentionType
@@ -525,6 +528,7 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device = 'hpu',
+        is_driver_worker: bool = False,
     ):
         # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_vllm_config(vllm_config)
@@ -538,6 +542,7 @@ def __init__(
         self.speculative_config = vllm_config.speculative_config
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
+        self.is_driver_worker = is_driver_worker
 
         self.sampler = get_sampler()
 
@@ -636,6 +641,9 @@ def __init__(
         # TODO(madamczyk-intel): add a knob for that
        # TODO(madamczyk-intel): debug why increasing it lowers acc
         self.logits_rounding = 1
+        # High-level profiler
+        self.profiler = HabanaHighLevelProfiler()
+        self.profiler_counter_helper = HabanaProfilerCounterHelper()
 
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
@@ -841,6 +849,7 @@ def _get_prompts_and_decodes(
 
         # Traverse decodes first
         decode_req_ids = []
+        num_computed_tokens_decode = []
         for i in range(num_reqs):
             req_id = self.input_batch.req_ids[i]
             assert req_id is not None
@@ -857,6 +866,11 @@ def _get_prompts_and_decodes(
             # This is decode
             assert num_scheduled_tokens == 1
             decode_req_ids.append(req_id)
+            num_computed_tokens_decode.append(int(num_computed_tokens + 1))
+
+        if self.profiler.enabled:
+            self.profiler_counter_helper.capture_decode_seq_stats(
+                num_computed_tokens_decode)
 
         # Traverse prompts
         prompt_req_ids = []
@@ -1071,6 +1085,8 @@ def _form_prefill_batch(self, contents):
         token_ids = contents.token_ids
         req_ids = contents.req_ids
         query_lens = [len(tids) for tids in contents.token_ids]
+        if self.profiler.enabled:
+            self.profiler_counter_helper.capture_prompt_seq_stats(query_lens)
         context_lens = contents.context_lens
 
         token_positions = [
@@ -1375,17 +1391,31 @@ def _execute_model_generic(self,
             # no hpu graphs for t.compile?
             use_graphs = False
         trimmed_attn_metadata = trim_attn_metadata(attn_metadata)
-        hidden_states = self.model.forward(input_ids=token_ids,
-                                           positions=position_ids,
-                                           attn_metadata=trimmed_attn_metadata,
-                                           kv_caches=kv_caches)
+        if self.is_driver_worker:
+            model_event_name = ("model_forward_"
+                                f"bs{batch_size}_"
+                                f"seq{seq_len}_"
+                                f"ctx{num_blocks}_"
+                                f"graphs{'T' if use_graphs else 'F'}")
+        else:
+            model_event_name = 'model_executable'
+        with self.profiler.record_event('internal', model_event_name):
+            hidden_states = self.model.forward(
+                input_ids=token_ids,
+                positions=position_ids,
+                attn_metadata=trimmed_attn_metadata,
+                kv_caches=kv_caches)
         # NOTE(kzawora): returning hidden_states is required in prompt logprobs
         # scenarios, as they will do logit processing on their own
         non_flattened_hidden_states = hidden_states
 
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
+        with self.profiler.record_event('internal', ('compute_logits'
+                                                     f'{batch_size}_'
+                                                     f'seq{seq_len}_ctx'
+                                                     f'{num_blocks}')):
+            logits = self.model.compute_logits(hidden_states, None)
         return non_flattened_hidden_states, logits
 
     def _get_prompt_logprobs_dict(
@@ -1532,8 +1562,9 @@ def execute_model(
         num_decodes = len(pd_info.decode_req_ids)
         num_prefills = len(pd_info.prompt_req_ids)
         num_reqs = num_decodes + num_prefills
-        prefill_data, decode_data = self._prepare_inputs(
-            scheduler_output, num_prefills, num_decodes)
+        with self.profiler.record_event('internal', 'prepare_input_tensors'):
+            prefill_data, decode_data = self._prepare_inputs(
+                scheduler_output, num_prefills, num_decodes)
 
         #FIXME(kzawora): Currently there's no handling of logprobs. Fix that
         # later.
@@ -1548,64 +1579,102 @@ def execute_model(
                   attn_metadata, logits_indices,
                   logits_requests) in enumerate(
                       zip(*shallow_tuple(prefill_data))):
+            self.event_start = self.profiler.get_timestamp_us()
+            self.profiler.start("internal", "prefill")
             htorch.core.mark_step()
             prefill_hidden_states_ts, logits_device = \
                 self._execute_model_generic(
                     token_ids, position_ids, attn_metadata, logits_indices,
                     self.kv_caches)
             htorch.core.mark_step()
-            sampling_metadata = self._prepare_sampling(
-                batch_changed, req_id, pad_to=logits_device.shape[0])
-            sampler_output = self.sampler(
-                logits=logits_device, sampling_metadata=sampling_metadata)
-            prefill_sampled_token_ids.append(
-                sampler_output.sampled_token_ids.flatten())
-            prefill_sampled_requests.extend(logits_requests)
+            with self.profiler.record_event('internal', "sampler"):
+                sampling_metadata = self._prepare_sampling(
+                    batch_changed, req_id, pad_to=logits_device.shape[0])
+                sampler_output = self.sampler(
+                    logits=logits_device,
+                    sampling_metadata=sampling_metadata)
+                prefill_sampled_token_ids.append(
+                    sampler_output.sampled_token_ids.flatten())
+                prefill_sampled_requests.extend(logits_requests)
             htorch.core.mark_step()
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model_generic' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=self._seq_len(attn_metadata),
+                    batch_size_padded=token_ids.size(0),
+                    real_batch_size=len(req_id),
+                    prompt_batch_idx=idx,
+                    is_prompt=True)
+                self.profiler.record_counter(self.event_start, counters)
+        if self.is_driver_worker and self.profiler.enabled:
+            self.profiler_counter_helper.reset_prompt_seq_stats()
 
         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
+            self.event_start = self.profiler.get_timestamp_us()
+            self.profiler.start("internal", "decode")
             assert decode_data is not None
             htorch.core.mark_step()
             _, logits_device = self._execute_model_generic(
                 decode_data.token_ids, decode_data.position_ids,
                 decode_data.attn_metadata, decode_data.logits_indices,
                 self.kv_caches)
             htorch.core.mark_step()
-            sampling_metadata = self._prepare_sampling(
-                batch_changed,
-                pd_info.decode_req_ids,
-                pad_to=logits_device.shape[0])
-            sampler_output = self.sampler(logits=logits_device,
-                                          sampling_metadata=sampling_metadata)
-            decode_sampled_token_ids.append(
-                sampler_output.sampled_token_ids.flatten())
-            decode_sampled_requests.extend(
-                self.input_batch.req_ids[:num_decodes])
+            with self.profiler.record_event('internal', "sampler"):
+                sampling_metadata = self._prepare_sampling(
+                    batch_changed,
+                    pd_info.decode_req_ids,
+                    pad_to=logits_device.shape[0])
+                sampler_output = self.sampler(
+                    logits=logits_device, sampling_metadata=sampling_metadata)
+                decode_sampled_token_ids.append(
+                    sampler_output.sampled_token_ids.flatten())
+                decode_sampled_requests.extend(
+                    self.input_batch.req_ids[:num_decodes])
             htorch.core.mark_step()
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=self._seq_len(decode_data.attn_metadata),
+                    batch_size_padded= \
+                        decode_data.token_ids.size(0),  # type: ignore
+                    real_batch_size=decode_data.num_decodes,
+                    prompt_batch_idx=None,
+                    is_prompt=False)
+                self.profiler.record_counter(self.event_start, counters)
         # From this point onward, all operations are done on CPU.
         # We already have tokens. Let's copy the data to
         # CPU as is, and then discard padded tokens.
-
-        prefill_sampled_token_ids = [
-            tensor.cpu() for tensor in prefill_sampled_token_ids
-        ]
-        decode_sampled_token_ids = [
-            tensor.cpu()[:num_decodes] for tensor in decode_sampled_token_ids
-        ]
-        sampled_token_ids_list = torch.cat(decode_sampled_token_ids +
-                                           prefill_sampled_token_ids).tolist()
-        sampled_token_requests = \
-            decode_sampled_requests + prefill_sampled_requests
-        max_req_index = max(self.input_batch.req_id_to_index.values())
-        postprocessed_sampled_token_ids: list[list]
-        postprocessed_sampled_token_ids = [[]
-                                           for _ in range(max_req_index + 1)]
-        for tok_id, req_id in zip(sampled_token_ids_list,
-                                  sampled_token_requests):
-            postprocessed_sampled_token_ids[
-                self.input_batch.req_id_to_index[req_id]].append(tok_id)
+        with self.profiler.record_event('internal', "sampler_postprocessing"):
+            prefill_sampled_token_ids = [
+                tensor.cpu() for tensor in prefill_sampled_token_ids
+            ]
+            decode_sampled_token_ids = [
+                tensor.cpu()[:num_decodes]
+                for tensor in decode_sampled_token_ids
+            ]
+            sampled_token_ids_list = torch.cat(
+                decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
+            sampled_token_requests = \
+                decode_sampled_requests + prefill_sampled_requests
+            max_req_index = max(self.input_batch.req_id_to_index.values())
+            postprocessed_sampled_token_ids: list[list]
+            postprocessed_sampled_token_ids = [[]
+                                               for _ in range(max_req_index +
                                                              1)]
+            for tok_id, req_id in zip(sampled_token_ids_list,
+                                      sampled_token_requests):
+                postprocessed_sampled_token_ids[
+                    self.input_batch.req_id_to_index[req_id]].append(tok_id)
 
         # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs
 
@@ -1796,6 +1865,14 @@ def warmup_scenario(self,
         slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)
 
         use_graphs = self._use_graphs()
+        phase = "prompt" if is_prompt else "decode"
+        scenario_name = ("warmup_"
+                         f"{phase}_"
+                         f"bs{batch_size}_"
+                         f"seq{query_seq_len}_"
+                         f"ctx{num_blocks}_"
+                         f"graphs{'T' if use_graphs else 'F'}")
+
         input_ids = torch.zeros((batch_size, query_seq_len),
                                 dtype=torch.int32,
                                 device='cpu')
@@ -1809,6 +1886,7 @@ def warmup_scenario(self,
         input_ids_device = _async_h2d_tensor_copy(input_ids, self.device)
         position_ids_device = _async_h2d_tensor_copy(position_ids, self.device)
         slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device)
+        self.profiler.start('internal', scenario_name)
         times = 3 if use_graphs or is_pt_profiler_run else 1
         for time_index in range(times):
             if is_prompt:
@@ -1882,6 +1960,7 @@ def warmup_scenario(self,
            } # NOTE(kzawora): idk what to set here
            max_num_logprobs = 0 # NOTE(kzawora): idk what to set here
            # NOTE(kzawora: do this in a smarter way)
+           self.profiler.end()
            return None
        htorch.core.mark_step()
        sampling_metadata = SamplingMetadata(
@@ -2138,6 +2217,7 @@ def warmup_model(self) -> None:
            logger.info("Skipping warmup...")
            return
 
+        self.profiler.start('internal', 'warmup')
        start_mem = HabanaMemoryProfiler.current_device_memory_usage()
        start_time = time.perf_counter()
 
@@ -2189,6 +2269,7 @@ def warmup_model(self) -> None:
              f"Warmup finished in {elapsed_time:.0f} secs, "
              f"allocated {format_bytes(end_mem - start_mem)} of device memory")
        logger.info(msg)
+        self.profiler.end()
 
     def shutdown_inc(self):
        can_finalize_inc = self._is_quant_with_inc() and \
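
For readers unfamiliar with the vllm-hpu-extension high-level profiler, the pattern the runner now follows is: bracket each phase with get_timestamp_us()/start()/end(), wrap sub-steps in record_event() context managers, and attach a counter dictionary to the phase's start timestamp via record_counter(). The sketch below is illustrative only and is not part of this commit; dummy_decode_step, dummy_forward, and the hard-coded seq_len and per-request stats are placeholders, and the is_driver_worker gating used in the real runner is omitted for brevity.

# Illustrative sketch only -- mirrors the profiling pattern added to
# hpu_model_runner.py; all "dummy_*" names and literal stats are
# placeholders, not vLLM API.
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                         HabanaProfilerCounterHelper)

profiler = HabanaHighLevelProfiler()
counter_helper = HabanaProfilerCounterHelper()


def dummy_decode_step(cache_config, token_ids, num_decodes, dummy_forward):
    # Bracket the phase with a timestamp pair and start()/end(), as
    # execute_model() does for "prefill" and "decode".
    event_start = profiler.get_timestamp_us()
    profiler.start("internal", "decode")

    # Wrap sub-steps (model forward, sampler, post-processing) in
    # record_event() so they appear as nested events.
    with profiler.record_event("internal", "sampler"):
        sampled = dummy_forward(token_ids)

    if profiler.enabled:
        # Stop the "decode" event and attach counters (duration, batch
        # sizes, sequence stats) to its start timestamp.
        profiler.end()
        event_end = profiler.get_timestamp_us()
        # The real runner feeds per-request computed-token counts here;
        # [1] * num_decodes is a stand-in.
        counter_helper.capture_decode_seq_stats([1] * num_decodes)
        counters = counter_helper.get_counter_dict(
            cache_config=cache_config,
            duration=event_end - event_start,
            seq_len=1,
            batch_size_padded=token_ids.size(0),
            real_batch_size=num_decodes,
            prompt_batch_idx=None,
            is_prompt=False)
        profiler.record_counter(event_start, counters)
    return sampled

In the runner itself, the named model_forward_*/warmup_* events and the counter recording are emitted only on the driver worker, which is why the is_driver_worker flag was added to the constructor.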
