
Commit 0de7a6b (parent: 939dd61)

Refactor bloom modeling and add tests

Signed-off-by: char-1ee <[email protected]>

File tree: 5 files changed, +360 / -179 lines

colossalai/inference/kv_cache/kvcache_manager.py (34 additions, 15 deletions)
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
@@ -15,9 +15,11 @@
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
+def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
     if hasattr(config, attr_name):
         return getattr(config, attr_name)
+    if alter_attr is not None:  # TODO, rebase caidi changes
+        return alter_attr
     elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
         return getattr(config, config.attribute_map[attr_name])
     raise AttributeError(f"{attr_name} is not found in config")
@@ -53,7 +55,12 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(
+        self,
+        config: InferenceConfig,
+        model_config: PretrainedConfig,
+        verbose: bool = False,
+    ) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -64,14 +71,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads", alter_attr=self.head_num)
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
 
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
-        else:
-            self.kv_head_num = self.head_num
+        # if hasattr(config, "num_key_value_heads"):
+        #     self.kv_head_num = getattr(config, "num_key_value_heads")
+        # elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+        #     self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        # else:
+        #     self.kv_head_num = self.head_num
 
         assert (
             self.kv_head_num % self.tp_size == 0
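The refactor folds the old inline fallback chain into get_model_config_attr, with head_num passed as the default for models that do not define grouped-query attention. A sketch of the expected behavior, assuming the configs behave as in transformers (the numbers here are illustrative, not taken from this commit):

from transformers import BloomConfig, LlamaConfig

# GQA-style config: num_key_value_heads exists, so it is returned directly.
gqa = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)
assert get_model_config_attr(gqa, "num_key_value_heads", alter_attr=64) == 8

# BLOOM defines no num_key_value_heads, so alter_attr (the query head count) is used.
mha = BloomConfig(n_head=32)
assert get_model_config_attr(mha, "num_key_value_heads", alter_attr=32) == 32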
@@ -90,7 +98,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
 
         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+        alloc_shape = (
+            self.num_blocks,
+            self.kv_head_num,
+            self.block_size,
+            self.head_size,
+        )
         self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
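For a sense of scale, alloc_shape is the per-layer shape of each of the two (K and V) cache tensors, so the total footprint comes out to roughly 2 * num_layers * num_blocks * kv_head_num * block_size * head_size * elem_size bytes. A back-of-the-envelope sketch with hypothetical fp16, 7B-class numbers (not taken from this commit):

num_layers, kv_head_num, head_size = 32, 32, 128  # hypothetical 7B-class model
num_blocks, block_size, elem_size = 2048, 16, 2   # fp16 elements

per_layer = num_blocks * kv_head_num * block_size * head_size * elem_size
total = 2 * num_layers * per_layer  # K and V caches
print(f"{total / 1024**3:.1f} GiB")  # 16.0 GiB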
@@ -202,7 +215,8 @@ def allocate_context_from_block_table(self, block_table: torch.Tensor, context_l
             block.add_ref()
             if block_id == block_indexes[-1].item():
                 self._allocate_on_block(
-                    block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
+                    block,
+                    (block.block_size if context_len % block.block_size == 0 else context_len % block.block_size),
                 )
             else:
                 self._allocate_on_block(block, block.block_size)
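The rewrapped argument is the last-block occupancy: a full block when context_len divides evenly by the block size, otherwise the remainder. Pulled out as a standalone sketch (the batched variant in the next hunk applies the same arithmetic per sequence):

def last_block_occupancy(context_len: int, block_size: int) -> int:
    # Slots used on the final block of a sequence of context_len tokens.
    return block_size if context_len % block_size == 0 else context_len % block_size

assert last_block_occupancy(33, 16) == 1   # two full blocks plus one token
assert last_block_occupancy(32, 16) == 16  # the final block is exactly full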
@@ -269,9 +283,11 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
                 block.add_ref()
                 self._allocate_on_block(
                     block,
-                    block.block_size
-                    if context_lengths[i] % block.block_size == 0
-                    else context_lengths[i].item() % block.block_size,
+                    (
+                        block.block_size
+                        if context_lengths[i] % block.block_size == 0
+                        else context_lengths[i].item() % block.block_size
+                    ),
                 )
         for block_id in alloc_block_ids:
             if block_id in alloc_block_ids[last_block_locs]:
@@ -444,7 +460,10 @@ def clear_all(self) -> None:
 
     def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
-        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
+        return (
+            self._kv_caches[0][layer_id][block_idx],
+            self._kv_caches[1][layer_id][block_idx],
+        )
 
     def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
         """Allocate a specific size of space on a provided cache block.
