Commit 6cac54f

[v1] Re-init input batch for multiple kv cache groups (vllm-project#18654)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 6865fe0 commit 6cac54f

6 files changed: +61 −46 lines changed

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 3 additions & 26 deletions
@@ -10,8 +10,6 @@
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -25,27 +23,6 @@
 MAX_NUM_PROMPT_TOKENS = 64
 
 
-def get_kv_cache_config() -> KVCacheConfig:
-    return KVCacheConfig(
-        num_blocks=10,
-        tensors={
-            "layer.0": KVCacheTensor(size=1024),
-        },
-        kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=1,
-                    num_kv_heads=1,
-                    head_size=16,
-                    dtype=torch.float16,
-                    use_mla=False,
-                ),
-            ),
-        ],
-    )
-
-
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
@@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
@@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
    )
 
    reqs: list[CachedRequestState] = []
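
The removed get_kv_cache_config() helper above was the single-group fixture the old tests relied on. As a rough illustration of what the new block_sizes=[...] argument is meant to support, the sketch below builds a hypothetical two-group KVCacheConfig using the same constructor arguments as that helper; the second group with a different block size and the "layer.1" tensor entry are made up for illustration and are not part of this commit.

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)


def get_two_group_kv_cache_config() -> KVCacheConfig:
    # Hypothetical config with two KV cache groups using different block sizes.
    return KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer.0": KVCacheTensor(size=1024),
            "layer.1": KVCacheTensor(size=1024),  # made-up second layer
        },
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=["layer.0"],
                kv_cache_spec=FullAttentionSpec(block_size=16,
                                                num_kv_heads=1,
                                                head_size=16,
                                                dtype=torch.float16,
                                                use_mla=False),
            ),
            KVCacheGroupSpec(
                layer_names=["layer.1"],
                kv_cache_spec=FullAttentionSpec(block_size=32,
                                                num_kv_heads=1,
                                                head_size=16,
                                                dtype=torch.float16,
                                                use_mla=False),
            ),
        ],
    )

An InputBatch built for a config like this would be constructed with block_sizes=[16, 32], one entry per group.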

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
         device=runner.device,
         pin_memory=runner.pin_memory,
         vocab_size=runner.model_config.get_vocab_size(),
-        block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
+        block_sizes=[
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+        ],
     )
     runner.initialize_attn_backend(kv_cache_config)
 
vllm/v1/worker/block_table.py

Lines changed: 2 additions & 1 deletion
@@ -105,10 +105,11 @@ class MultiGroupBlockTable:
 
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
-                 device: torch.device, block_size: int) -> None:
+                 device: torch.device, block_sizes: list[int]) -> None:
         self.block_tables = [
             BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
                        max_num_batched_tokens, pin_memory, device)
+            for block_size in block_sizes
         ]
 
     def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
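
With this change, MultiGroupBlockTable builds one BlockTable per entry in block_sizes, i.e. one per KV cache group, each sized with cdiv(max_model_len, block_size) block slots per request. A rough usage sketch follows; the sizes are made up, only the parameter names and the block_tables attribute come from the diff above.

import torch

from vllm.v1.worker.block_table import MultiGroupBlockTable

# Two KV cache groups with different block sizes -> two per-group block tables.
block_table = MultiGroupBlockTable(
    max_num_reqs=32,
    max_model_len=2048,
    max_num_batched_tokens=2048,
    pin_memory=False,
    device=torch.device("cpu"),
    block_sizes=[16, 32],
)
assert len(block_table.block_tables) == 2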

vllm/v1/worker/gpu_input_batch.py

Lines changed: 9 additions & 9 deletions
@@ -56,14 +56,14 @@ def get_token_id(self, idx: int) -> int:
 class InputBatch:
 
     def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        max_num_batched_tokens: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        block_size: int,
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        device: torch.device,
+        pin_memory: bool,
+        vocab_size: int,
+        block_sizes: list[int],  # The block_size of each kv cache group
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -105,7 +105,7 @@ def __init__(
             max_num_batched_tokens=max_num_batched_tokens,
             pin_memory=pin_memory,
             device=device,
-            block_size=block_size,
+            block_sizes=block_sizes,
         )
 
         # Sampling-related.
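
The constructor change flows straight through to MultiGroupBlockTable. Below is a rough sketch of building the updated InputBatch directly; the sizes are illustrative, while the parameter names are exactly those shown in the diff above.

import torch

from vllm.v1.worker.gpu_input_batch import InputBatch

input_batch = InputBatch(
    max_num_reqs=32,
    max_model_len=2048,
    max_num_batched_tokens=2048,
    device=torch.device("cpu"),
    pin_memory=False,
    vocab_size=1024,
    block_sizes=[16, 32],  # one entry per KV cache group
)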

vllm/v1/worker/gpu_model_runner.py

Lines changed: 40 additions & 6 deletions
@@ -143,7 +143,6 @@ def __init__(
         self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
         self.attn_backends: list[type[AttentionBackend]] = []
         # self.kv_cache_config: KVCacheConfig
-        # self.input_batch: InputBatch  # Persistent batch.
 
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -173,14 +172,23 @@ def __init__(
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
 
+        # Input Batch
+        # NOTE(Chen): Ideally, we should initialize the input batch inside
+        # `initialize_kv_cache` based on the kv cache config. However, as in
+        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
+        # reasons, we have to initialize the input batch before `load_model`,
+        # quantization + weight offloading will fail otherwise. As a temporary
+        # solution, we initialize the input batch here, and re-initialize it
+        # in `initialize_kv_cache` if the block_sizes here is different from
+        # the block_sizes in the kv cache config.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.cache_config.block_size,
+            block_sizes=[self.cache_config.block_size],
         )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -2040,18 +2048,44 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
 
+    def may_reinitialize_input_batch(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Re-initialize the input batch if the block sizes are different from
+        `[self.cache_config.block_size]`. This usually happens when there
+        are multiple KV cache groups.
+
+        Args:
+            kv_cache_config: The KV cache configuration.
+        """
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+        if block_sizes != [self.cache_config.block_size]:
+            assert self.cache_config.cpu_offload_gb == 0, (
+                "Cannot re-initialize the input batch when CPU weight "
+                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                "for more details.")
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=block_sizes,
+            )
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
             kv_cache_config: Configuration for the KV cache, including the KV
                 cache size of each layer
         """
-        if len(kv_cache_config.kv_cache_groups) > 1:
-            raise NotImplementedError(
-                "Hybrid models with more than one KV cache type are not "
-                "supported yet.")
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
 
         kv_caches: dict[str, torch.Tensor] = {}
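
The new guard can be restated as a small standalone predicate: re-initialization is needed exactly when the per-group block sizes no longer match the single default block size the runner was constructed with. The helper name and duck-typed argument below are hypothetical; the real logic lives in GPUModelRunner.may_reinitialize_input_batch() and rebuilds self.input_batch in place.

def block_sizes_need_reinit(kv_cache_config, default_block_size: int) -> bool:
    # One block size per KV cache group, in group order.
    block_sizes = [
        group.kv_cache_spec.block_size
        for group in kv_cache_config.kv_cache_groups
    ]
    # A single group with the default block size keeps the original batch.
    return block_sizes != [default_block_size]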

vllm/v1/worker/tpu_model_runner.py

Lines changed: 4 additions & 3 deletions
@@ -200,7 +200,7 @@ def __init__(
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.block_size,
+            block_sizes=[self.block_size],
         )
 
         # Cached torch/numpy tensor
@@ -1358,8 +1358,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
-            block_size,
+            block_sizes=[
+                kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+            ],
         )
         # Verify dtype compatibility between block_table_cpu and input_batch
         assert self.block_table_cpu.dtype == self.input_batch.block_table[
