@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
@@ -15,9 +15,11 @@
 GIGABYTE = 1024**3
 
 
-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
+def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
     if hasattr(config, attr_name):
         return getattr(config, attr_name)
+    if alter_attr is not None:  # TODO, rebase caidi changes
+        return alter_attr
     elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
         return getattr(config, config.attribute_map[attr_name])
     raise AttributeError(f"{attr_name} is not found in config")
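
With the new `alter_attr` parameter, callers can pass a fallback value instead of catching `AttributeError` themselves. A minimal sketch of the lookup order, using hypothetical config values:

```python
from transformers.configuration_utils import PretrainedConfig

config = PretrainedConfig()
config.num_attention_heads = 32  # this config defines no num_key_value_heads

# The attribute is absent, so the fallback supplied via alter_attr is returned:
kv_heads = get_model_config_attr(config, "num_key_value_heads", alter_attr=32)
assert kv_heads == 32
```

Note the precedence this diff introduces: a non-None `alter_attr` is returned before the `attribute_map` lookup runs, so mapped attribute names are only consulted when no fallback is given.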
@@ -53,7 +55,12 @@ class KVCacheManager:
     And it's possible to have a batch of sequences with different lengths of block tables.
     """
 
-    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+    def __init__(
+        self,
+        config: InferenceConfig,
+        model_config: PretrainedConfig,
+        verbose: bool = False,
+    ) -> None:
         self.logger = get_dist_logger(__name__)
         self.device = get_current_device()
 
@@ -64,14 +71,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads", alter_attr=self.head_num)
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
 
-        if hasattr(config, "num_key_value_heads"):
-            self.kv_head_num = getattr(config, "num_key_value_heads")
-        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
-        else:
-            self.kv_head_num = self.head_num
+        # if hasattr(config, "num_key_value_heads"):
+        #     self.kv_head_num = getattr(config, "num_key_value_heads")
+        # elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+        #     self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        # else:
+        #     self.kv_head_num = self.head_num
 
         assert (
             self.kv_head_num % self.tp_size == 0
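
The head bookkeeping above also covers grouped-query attention, where a model has fewer KV heads than query heads. A worked sketch with illustrative, Llama-2-70B-style numbers (not read from a real config):

```python
num_layers = 80                      # num_hidden_layers
head_num = 64                        # num_attention_heads
kv_head_num = 8                      # num_key_value_heads; falls back to head_num for plain MHA
head_size = 8192 // head_num         # hidden_size // head_num = 128

tp_size = 2                          # assumed tensor-parallel degree
assert kv_head_num % tp_size == 0    # each rank holds kv_head_num // tp_size = 4 KV heads
```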
@@ -90,7 +98,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
 
         # Physical cache allocation
-        alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+        alloc_shape = (
+            self.num_blocks,
+            self.kv_head_num,
+            self.block_size,
+            self.head_size,
+        )
         self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
         self._kv_caches = self._init_device_caches(alloc_shape)
         self.total_physical_cache_size_in_bytes = (
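
Since a separate cache is kept per layer for both K and V, the physical allocation dominates memory planning. A back-of-the-envelope sketch, assuming the total size is elem_size × 2 (K and V) × num_layers × num_blocks × kv_head_num × block_size × head_size, with illustrative fp16 values:

```python
GIGABYTE = 1024**3

# Illustrative values, not taken from a real model config:
num_layers, num_blocks, kv_head_num, block_size, head_size = 32, 512, 32, 16, 128
elem_size_in_bytes = 2  # fp16

total = 2 * num_layers * num_blocks * kv_head_num * block_size * head_size * elem_size_in_bytes
print(f"{total / GIGABYTE:.2f} GiB")  # 4.00 GiB for this configuration
```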
@@ -202,7 +215,8 @@ def allocate_context_from_block_table(self, block_table: torch.Tensor, context_l
             block.add_ref()
             if block_id == block_indexes[-1].item():
                 self._allocate_on_block(
-                    block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
+                    block,
+                    (block.block_size if context_len % block.block_size == 0 else context_len % block.block_size),
                 )
             else:
                 self._allocate_on_block(block, block.block_size)
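
The parenthesized conditional decides how many slots of a sequence's last block are actually occupied: a full block when the context length divides evenly, otherwise the remainder. A small sketch with an assumed block size of 16:

```python
block_size = 16
for context_len in (33, 48):
    space_asked = block_size if context_len % block_size == 0 else context_len % block_size
    print(context_len, "->", space_asked)
# 33 -> 1   (only one slot of the last block is used)
# 48 -> 16  (an exact multiple of block_size fills the last block)
```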
@@ -269,9 +283,11 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
             block.add_ref()
             self._allocate_on_block(
                 block,
-                block.block_size
-                if context_lengths[i] % block.block_size == 0
-                else context_lengths[i].item() % block.block_size,
+                (
+                    block.block_size
+                    if context_lengths[i] % block.block_size == 0
+                    else context_lengths[i].item() % block.block_size
+                ),
             )
         for block_id in alloc_block_ids:
             if block_id in alloc_block_ids[last_block_locs]:
@@ -444,7 +460,10 @@ def clear_all(self) -> None:
 
     def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the K and V cache tensors corresponding to the block with the given id for a specific layer."""
-        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
+        return (
+            self._kv_caches[0][layer_id][block_idx],
+            self._kv_caches[1][layer_id][block_idx],
+        )
 
     def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
         """Allocate a specific size of space on a provided cache block.