@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
@@ -15,9 +15,11 @@
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
+def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
     if hasattr(config, attr_name):
         return getattr(config, attr_name)
+    if alter_attr is not None:  # TODO, rebase caidi changes
+        return alter_attr
     elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
         return getattr(config, config.attribute_map[attr_name])
     raise AttributeError(f"{attr_name} is not found in config")
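
A minimal usage sketch (not part of the diff, assuming the get_model_config_attr defined above is in scope and a GPT-2-style config that has no num_key_value_heads attribute) of how the new alter_attr fallback resolves:

from transformers import GPT2Config

cfg = GPT2Config(n_head=12)  # multi-head attention config; no num_key_value_heads attribute

get_model_config_attr(cfg, "num_attention_heads")                 # 12 (GPT-2 maps it to n_head)
get_model_config_attr(cfg, "num_key_value_heads", alter_attr=12)  # 12, returned via the new alter_attr branch
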
@@ -53,7 +55,12 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(
+        self,
+        config: InferenceConfig,
+        model_config: PretrainedConfig,
+        verbose: bool = False,
+    ) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -64,14 +71,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads", alter_attr=self.head_num)
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
 
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
-        else:
-            self.kv_head_num = self.head_num
+        # if hasattr(config, "num_key_value_heads"):
+        #     self.kv_head_num = getattr(config, "num_key_value_heads")
+        # elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+        #     self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        # else:
+        #     self.kv_head_num = self.head_num
 
         assert (
             self.kv_head_num % self.tp_size == 0
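
For context, an illustrative sketch with assumed values (not from the diff): grouped-query-attention configs expose fewer KV heads than query heads, plain MHA configs rely on the alter_attr fallback shown above, and the assertion requires the KV head count to split evenly across tensor-parallel ranks.

from transformers import LlamaConfig

gqa_cfg = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)  # GQA: 8 KV heads serve 32 query heads

head_num = get_model_config_attr(gqa_cfg, "num_attention_heads")                          # 32
kv_head_num = get_model_config_attr(gqa_cfg, "num_key_value_heads", alter_attr=head_num)  # 8

tp_size = 2
assert kv_head_num % tp_size == 0  # each tensor-parallel rank caches kv_head_num // tp_size = 4 KV heads
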
@@ -211,7 +219,8 @@ def allocate_context_from_block_table(self, block_table: torch.Tensor, context_l
             block.add_ref()
             if block_id == block_indexes[-1].item():
                 self._allocate_on_block(
-                    block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
+                    block,
+                    (block.block_size if context_len % block.block_size == 0 else context_len % block.block_size),
                 )
             else:
                 self._allocate_on_block(block, block.block_size)
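
The reformatted conditional above keeps its original meaning: the last block of a sequence is allocated only the remainder of the context length modulo the block size, unless the context fills it exactly. A small worked example with assumed numbers and a hypothetical helper mirroring that expression:

def last_block_alloc(context_len: int, block_size: int = 16) -> int:
    # Same expression as in the diff: full block when evenly divisible, remainder otherwise.
    return block_size if context_len % block_size == 0 else context_len % block_size

assert last_block_alloc(33) == 1    # 3 blocks used: 16 + 16 + 1
assert last_block_alloc(32) == 16   # exactly 2 full blocks; the last block is fully allocated
assert last_block_alloc(7) == 7     # a single, partially filled block
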
@@ -278,9 +287,11 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
             block.add_ref()
             self._allocate_on_block(
                 block,
-                block.block_size
-                if context_lengths[i] % block.block_size == 0
-                else context_lengths[i].item() % block.block_size,
+                (
+                    block.block_size
+                    if context_lengths[i] % block.block_size == 0
+                    else context_lengths[i].item() % block.block_size
+                ),
             )
         for block_id in alloc_block_ids:
             if block_id in alloc_block_ids[last_block_locs]:
@@ -453,7 +464,10 @@ def clear_all(self) -> None:
 
     def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
-        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
+        return (
+            self._kv_caches[0][layer_id][block_idx],
+            self._kv_caches[1][layer_id][block_idx],
+        )
 
     def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
         """Allocate a specific size of space on a provided cache block.