Commit 046343b

[V1] Defragmentation support (#1568)
extension PR: HabanaAI/vllm-hpu-extension#275

Signed-off-by: Michal Adamczyk <[email protected]>
1 parent e9c83fc commit 046343b

File tree

4 files changed: +61 −31 lines changed


requirements/hpu.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@ ray
 triton==3.1.0
 setuptools>=77.0.3
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@1e96318
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@cd79204

 # Dependencies for HPU vllm docker image
 datasets
```

vllm/v1/worker/hpu_model_runner.py

Lines changed: 28 additions & 28 deletions
```diff
@@ -17,11 +17,13 @@
 import torch.distributed
 import vllm_hpu_extension.environment as environment
 from vllm_hpu_extension.bucketing.common import HPUBucketingManager
+from vllm_hpu_extension.defragmentation import OnlineDefragmenter
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler,
                                          HabanaProfilerCounterHelper,
-                                         format_bytes)
-from vllm_hpu_extension.runtime import get_config
+                                         format_bytes, setup_profiler)
+from vllm_hpu_extension.runtime import finalize_config, get_config
+from vllm_hpu_extension.utils import pad_list

 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
@@ -59,25 +61,6 @@
 _TYPE_CACHE: dict[str, dict[str, Any]] = {}


-def setup_profiler(warmup, active):
-    schedule = torch.profiler.schedule(wait=0,
-                                       warmup=warmup,
-                                       active=active,
-                                       repeat=1)
-    activities = [
-        torch.profiler.ProfilerActivity.CPU,
-        torch.profiler.ProfilerActivity.HPU
-    ]
-    profiler = torch.profiler.profile(
-        schedule=schedule,
-        activities=activities,
-        on_trace_ready=torch.profiler.tensorboard_trace_handler('.',
-                                                                use_gzip=True),
-        record_shapes=False,
-        with_stack=True)
-    return profiler
-
-
 @dataclass
 class PromptDecodeInfo:
     prompt_req_ids: list[str]
@@ -541,13 +524,6 @@ def round_up(value: int, k: int):
     return (value + k - 1) // k * k


-def pad_list(input, target_len, val_generator):
-    padding = target_len - len(input)
-    if padding > 0:
-        input.extend(itertools.islice(val_generator, padding))
-    return input
-
-
 class HPUModelRunner:

     def __init__(
@@ -558,6 +534,8 @@ def __init__(
     ):
         # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_vllm_config(vllm_config)
+        finalize_config()
+
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -671,6 +649,8 @@ def __init__(
         self.profiler = HabanaHighLevelProfiler()
         self.profiler_counter_helper = HabanaProfilerCounterHelper()

+        self.defragmenter = OnlineDefragmenter()
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
@@ -1075,6 +1055,7 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes,
             num_blocks = round_up(context_len + query_len,
                                   self.block_size) // self.block_size
             blocks = block_table_cpu_tensor[batch_idx, :num_blocks].tolist()
+            blocks = [self.defragmenter.resolve(b) for b in blocks]

             prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx]
             #TODO: Fix non-prompt case
@@ -1311,6 +1292,8 @@ def _prepare_decode_inputs(self, num_decodes,
             dim=1,
             index=(index //
                    self.block_size))
+        block_number.apply_(self.defragmenter.resolve)
+
         block_offsets = padded_index % self.block_size
         slot_mapping = block_number * self.block_size + block_offsets
         # set an out of range value for the padding tokens so that they
@@ -1320,6 +1303,8 @@ def _prepare_decode_inputs(self, num_decodes,
             range(self._PAD_SLOT_ID, self._PAD_SLOT_ID + self.block_size))
         slot_mapping[num_decodes:].apply_(lambda _, ds=dummy_slots: next(ds))

+        block_tables_list = self.defragmenter.resolve_all(block_tables_list)
+
         # CONTEXT_LENS [batch_size]
         block_list, block_groups, block_usage = \
             self.get_habana_paged_attn_buffers(
@@ -1598,6 +1583,20 @@ def execute_model(
         # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] # noqa
         # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2]

+        if self.defragmenter.enabled and self.kv_caches:
+            new = {
+                req.req_id: flatten(req.block_ids)
+                for req in scheduler_output.scheduled_new_reqs if req.block_ids
+            }
+            cached = {
+                req.req_id: flatten(req.new_block_ids)
+                for req in scheduler_output.scheduled_cached_reqs
+                if req.new_block_ids
+            }
+            self.defragmenter.update_state(new | cached,
+                                           scheduler_output.finished_req_ids)
+            self.defragmenter.defragment()
+
         batch_changed = self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             if not has_kv_transfer_group():
@@ -2202,6 +2201,7 @@ def _read_profiling_cfg(self):

     @torch.inference_mode()
     def warmup_model(self) -> None:
+        self.defragmenter.initialize(self.kv_caches, self.block_size)
         if not self.enable_bucketing:
             return
         prompt_profile_cfg, decode_profile_cfg = self._read_profiling_cfg()
```
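
Taken together, these changes establish a small protocol around OnlineDefragmenter: initialize it once the KV caches exist (in warmup_model), feed it block ownership updates and run one defragment pass at the top of execute_model, and route every block id through resolve()/resolve_all() wherever prefill or decode block tables are materialized. Below is a minimal, self-contained sketch of that protocol using a hypothetical ToyDefragmenter; the real extension class also swaps the underlying KV-cache blocks on device, which this stub omits.

```python
class ToyDefragmenter:
    """Hypothetical stand-in mirroring the OnlineDefragmenter call surface."""

    def __init__(self) -> None:
        self.enabled = True
        self.mapping: dict[int, int] = {}      # scheduler block id -> physical id
        self.owned: dict[str, list[int]] = {}  # req_id -> blocks it holds

    def initialize(self, kv_caches, block_size: int) -> None:
        self.kv_caches = kv_caches
        self.block_size = block_size

    def update_state(self, updated: dict[str, list[int]],
                     finished_req_ids) -> None:
        # Record new blocks for live requests; forget finished requests.
        for req_id, blocks in updated.items():
            self.owned.setdefault(req_id, []).extend(blocks)
        for req_id in finished_req_ids:
            self.owned.pop(req_id, None)

    def defragment(self) -> None:
        # Compact live blocks into the lowest physical slots and remember the
        # moves. (The real implementation also copies KV data between blocks.)
        live = sorted(b for blocks in self.owned.values() for b in blocks)
        self.mapping = {b: i for i, b in enumerate(live)}

    def resolve(self, block_id: int) -> int:
        return self.mapping.get(block_id, block_id)

    def resolve_all(self, tables: list[list[int]]) -> list[list[int]]:
        return [[self.resolve(b) for b in t] for t in tables]


# Mirrors the call order in the diff: update_state + defragment each step,
# then resolve ids when block tables are built.
defrag = ToyDefragmenter()
defrag.initialize(kv_caches=[], block_size=128)
defrag.update_state({'req-0': [5, 9], 'req-1': [12]}, finished_req_ids=set())
defrag.defragment()
assert defrag.resolve_all([[5, 9], [12]]) == [[0, 1], [2]]
```

With this shape, callers never observe moved blocks: they keep the scheduler-assigned ids, and the resolve() indirection supplies the current physical location.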

vllm/v1/worker/hpu_worker.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -10,7 +10,10 @@
 import torch
 import torch.distributed
 import torch.nn as nn
-from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
+from vllm_hpu_extension.debug import init_debug_logger
+from vllm_hpu_extension.profiler import (HabanaMemoryProfiler, format_bytes,
+                                         setup_profiler)
+from vllm_hpu_extension.runtime import get_config

 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -32,6 +35,14 @@
 from vllm.v1.core.scheduler import SchedulerOutput


+def setup_step_profiler(steps):
+    if steps is None:
+        return None
+    step_start, step_end = steps
+    active = step_end - step_start + 1
+    return setup_profiler(warmup=0, active=active)
+
+
 class HPUWorker:

     def __init__(
@@ -76,6 +87,10 @@ def __init__(
         self.gc_track_recompiles = bool(
             "PT_HPU_METRICS_GC_DETAILS" in os.environ
             and bool_helper(os.getenv("PT_HPU_METRICS_GC_DETAILS")))
+        self.step = 0
+        self.profile_steps = get_config().VLLM_PROFILE_STEPS
+        self.step_profiler = setup_step_profiler(self.profile_steps)
+        self.step_debug = init_debug_logger('steps')

     def init_profiler(self):
         """Initialize the profiler."""
@@ -254,11 +269,23 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
+        if self.step_debug:
+            self.step_debug(f'step={self.step}')
+        if self.step_profiler and self.step == self.profile_steps[0]:
+            self.step_profiler.start()
         with track_graph_compile('HPUWorker.execute_model') \
                 if self.gc_track_recompiles \
                 else contextlib.nullcontext():
             output = self.model_runner.execute_model(scheduler_output)
         # TODO(woosuk): Send the output to the engine process.
+        if self.step_profiler:
+            if self.step >= self.profile_steps[0]:
+                self.step_profiler.step()
+            if self.step == self.profile_steps[1]:
+                self.step_profiler.stop()
+                self.step_profiler = None
+                raise RuntimeError('Step profiling finished!')
+        self.step += 1
         return output if self.rank == 0 else None

     def profile(self, is_start: bool = True):
```
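
setup_step_profiler turns VLLM_PROFILE_STEPS into a capture window that is inclusive on both ends, which is why active is computed as step_end - step_start + 1, and the run is deliberately aborted with a RuntimeError once the trace is written. A small sketch of just that window arithmetic, with the torch profiler replaced by a hypothetical stub so the control flow is visible on its own:

```python
class StubProfiler:
    """Hypothetical stand-in for the object returned by
    setup_profiler(warmup=0, active=step_end - step_start + 1)."""

    def start(self) -> None:
        print('trace started')

    def step(self) -> None:
        print('step recorded')

    def stop(self) -> None:
        print('trace written')


def run_steps(profile_steps: tuple[int, int], total_steps: int = 10) -> None:
    step_start, step_end = profile_steps  # inclusive on both sides
    profiler = StubProfiler()
    for step in range(total_steps):
        if step == step_start:
            profiler.start()
        # ... model_runner.execute_model(scheduler_output) runs here ...
        if step >= step_start:
            profiler.step()
        if step == step_end:
            profiler.stop()
            return  # the real worker raises RuntimeError at this point


run_steps((5, 7))  # records steps 5, 6 and 7 -> active=3
```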

vllm/worker/hpu_model_runner.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -36,7 +36,7 @@
                                          HabanaMemoryProfiler,
                                          HabanaProfilerCounterHelper,
                                          format_bytes)
-from vllm_hpu_extension.runtime import get_config
+from vllm_hpu_extension.runtime import finalize_config, get_config

 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -968,7 +968,10 @@ def __init__(
         is_causal: bool = True,
     ):
         ModelRunnerBase.__init__(self, vllm_config=vllm_config)
+
         environment.set_vllm_config(vllm_config)
+        finalize_config()
+
         self.is_driver_worker = is_driver_worker
         self.return_hidden_states = return_hidden_states

```
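Both the V1 and legacy model runners now call finalize_config() immediately after environment.set_vllm_config(vllm_config), so every later get_config() read (for example VLLM_PROFILE_STEPS in HPUWorker) sees a single frozen snapshot. A hypothetical mirror of that lifecycle, meant to illustrate the ordering constraint rather than the extension's actual internals:

```python
# Hypothetical mirror of the vllm_hpu_extension config lifecycle; the names
# match the imports in the diff, but the bodies are illustrative only.
_vllm_config = None
_frozen_config = None


def set_vllm_config(cfg: dict) -> None:
    global _vllm_config
    _vllm_config = cfg


def finalize_config() -> None:
    # Freeze the runtime view once, before anyone reads it.
    global _frozen_config
    _frozen_config = dict(_vllm_config or {})


def get_config() -> dict:
    assert _frozen_config is not None, 'finalize_config() was not called'
    return _frozen_config


set_vllm_config({'VLLM_PROFILE_STEPS': (5, 7)})
finalize_config()
assert get_config()['VLLM_PROFILE_STEPS'] == (5, 7)
```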