Skip to content

Commit 67d67fb

Browse files
committed
Refactor bloom modeling and add tests
Signed-off-by: char-1ee <[email protected]>
1 parent 7f9f667 commit 67d67fb

File tree

5 files changed

+354
-178
lines changed

5 files changed

+354
-178
lines changed

colossalai/inference/kv_cache/kvcache_manager.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import Any, List, Tuple
22

33
import torch
44
from transformers.configuration_utils import PretrainedConfig
@@ -15,9 +15,11 @@
1515
GIGABYTE = 1024**3
1616

1717

def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
    """Look up ``attr_name`` on a model config, resolving known aliases.

    Lookup order:
      1. the attribute itself on ``config``;
      2. the attribute that ``config.attribute_map`` aliases ``attr_name`` to;
      3. the caller-supplied fallback ``alter_attr`` (when not ``None``).

    Args:
        config: A transformers ``PretrainedConfig`` (or duck-typed) instance.
        attr_name: Name of the attribute to read.
        alter_attr: Fallback value to return when the attribute cannot be
            resolved on the config; ``None`` disables the fallback.

    Raises:
        AttributeError: If the attribute cannot be resolved and no fallback
            was provided.
    """
    if hasattr(config, attr_name):
        return getattr(config, attr_name)
    # Consult the config's alias table BEFORE the caller fallback, so a
    # model-provided value still wins over a generic default — this matches
    # the pre-refactor lookup order in KVCacheManager.  Guard the key with
    # `attr_name in attribute_map` so a missing alias raises the documented
    # AttributeError below instead of leaking a KeyError.
    attribute_map = getattr(config, "attribute_map", None)
    if attribute_map and attr_name in attribute_map and hasattr(config, attribute_map[attr_name]):
        return getattr(config, attribute_map[attr_name])
    if alter_attr is not None:
        return alter_attr
    raise AttributeError(f"{attr_name} is not found in config")
@@ -53,7 +55,12 @@ class KVCacheManager:
5355
And it's possible to have a batch of sequences with different lengths of block tables.
5456
"""
5557

56-
def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
58+
def __init__(
59+
self,
60+
config: InferenceConfig,
61+
model_config: PretrainedConfig,
62+
verbose: bool = False,
63+
) -> None:
5764
self.logger = get_dist_logger(__name__)
5865
self.device = get_current_device()
5966

@@ -64,14 +71,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
6471
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
6572
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
6673
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
74+
self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads", alter_attr=self.head_num)
6775
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
6876

69-
if hasattr(config, "num_key_value_heads"):
70-
self.kv_head_num = getattr(config, "num_key_value_heads")
71-
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
72-
self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
73-
else:
74-
self.kv_head_num = self.head_num
77+
# if hasattr(config, "num_key_value_heads"):
78+
# self.kv_head_num = getattr(config, "num_key_value_heads")
79+
# elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
80+
# self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
81+
# else:
82+
# self.kv_head_num = self.head_num
7583

7684
assert (
7785
self.kv_head_num % self.tp_size == 0
@@ -211,7 +219,8 @@ def allocate_context_from_block_table(self, block_table: torch.Tensor, context_l
211219
block.add_ref()
212220
if block_id == block_indexes[-1].item():
213221
self._allocate_on_block(
214-
block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
222+
block,
223+
(block.block_size if context_len % block.block_size == 0 else context_len % block.block_size),
215224
)
216225
else:
217226
self._allocate_on_block(block, block.block_size)
@@ -278,9 +287,11 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
278287
block.add_ref()
279288
self._allocate_on_block(
280289
block,
281-
block.block_size
282-
if context_lengths[i] % block.block_size == 0
283-
else context_lengths[i].item() % block.block_size,
290+
(
291+
block.block_size
292+
if context_lengths[i] % block.block_size == 0
293+
else context_lengths[i].item() % block.block_size
294+
),
284295
)
285296
for block_id in alloc_block_ids:
286297
if block_id in alloc_block_ids[last_block_locs]:
@@ -453,7 +464,10 @@ def clear_all(self) -> None:
453464

454465
def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (key, value) cache tensors for one block of one layer.

    Args:
        layer_id: Index of the transformer layer to read from.
        block_idx: Index of the cache block within that layer.

    Returns:
        A ``(key_block, value_block)`` pair of tensors taken from the
        underlying key/value cache pools.
    """
    # _kv_caches holds the key pool at index 0 and the value pool at index 1.
    key_pool, value_pool = self._kv_caches
    return key_pool[layer_id][block_idx], value_pool[layer_id][block_idx]
457471

458472
def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
459473
"""Allocate a specific size of space on a provided cache block.

0 commit comments

Comments
 (0)