+import copy
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
@@ -239,7 +240,7 @@ def __str__(self):

     def __init__(
         self,
-        config: ModelArgs,
+        config_or_model: Union[ModelArgs, nn.Module],
         input_len: int,
         cache_lens: Union[int, List[int]],
         batch_size: int = 1,
@@ -248,8 +249,10 @@ def __init__(
         mask_val: float = float("-inf"),
     ):
         if isinstance(cache_lens, int):
-            cache_lens = [cache_lens] * config.n_layers
-        assert len(cache_lens) == config.n_layers
+            cache_lens_dict = defaultdict(lambda x=cache_lens: x)
+            cache_lens = [cache_lens]
+        else:
+            cache_lens_dict = dict(enumerate(cache_lens))

         self._masks = {
             cl: StaticAttentionMask(
@@ -258,6 +261,24 @@ def __init__(
             for cl in set(cache_lens)
         }

+        if isinstance(config_or_model, ModelArgs):
+            self._from_config(config_or_model, cache_lens_dict, batch_size, dtype)
+        else:
+            self._from_model(config_or_model, cache_lens_dict, batch_size, dtype)
+
+        self.input_len = input_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def _from_config(
+        self,
+        config: ModelArgs,
+        cache_lens: Dict[int, int],
+        batch_size: int,
+        dtype: torch.dtype,
+    ):
         rope = Rope(config)
         freqs = rope.get_freqs(None, config.max_context_len)
         self.freqs_cos = freqs[0].to(dtype)
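
Not part of the patch: a minimal sketch of what the cache_lens handling and the constructor dispatch above produce. An int cache length is broadcast to every layer through a defaultdict, a list becomes an explicit {layer_id: length} mapping, and the first positional argument then routes to either _from_config (ModelArgs) or _from_model (nn.Module).

# Illustration only (not from the patch) of the cache_lens normalization above.
from collections import defaultdict

cache_lens = 512
lens = defaultdict(lambda x=cache_lens: x)  # factory returns the captured int
assert lens[0] == 512 and lens[31] == 512   # every layer_id falls back to 512

lens = dict(enumerate([256, 256, 512]))     # explicit per-layer cache lengths
assert lens == {0: 256, 1: 256, 2: 512}
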
@@ -311,13 +332,63 @@ def __init__(
             if cache_lens[layer_id] > 0
         }

-        self.config = config
-        self.input_len = input_len
-        self.cache_lens = cache_lens
-        self.style = style
-        self.mask_val = mask_val
-        self.pos = 0
-        self.cache_full = False
+        self.generate_full_logits = config.generate_full_logits
+
+    def _from_model(
+        self,
+        config: nn.Module,
+        cache_lens: Dict[int, int],
+        batch_size: int,
+        dtype: torch.dtype,
+    ):
+        static_attentions = []
+        for module in config.modules():
+            if isinstance(module, StaticAttention):
+                static_attentions.append(module)
+
+        if not static_attentions:
+            raise ValueError("No StaticAttention modules found in the provided module")
+
+        config = copy.copy(static_attentions[0].rope.config)
+        config.use_hf_rope = static_attentions[0].rope.use_hf_rope
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_context_len)
+        self.freqs_cos = freqs[0].to(dtype)
+        self.freqs_sin = freqs[1].to(dtype)
+
+        self.k_caches = {}
+        self.v_caches = {}
+        for attn in static_attentions:
+            if attn.split_mha:
+                for head_id in range(attn.n_heads):
+                    cache_key = StaticKVCache.calculate_cache_key(
+                        attn.layer_id, head_id
+                    )
+                    for cache in (self.k_caches, self.v_caches):
+                        assert (
+                            cache_key not in cache
+                        ), "Found StaticAttention modules with duplicated layer_id"
+                        cache[cache_key] = torch.zeros(
+                            batch_size,
+                            cache_lens[attn.layer_id],
+                            attn.head_dim,
+                            dtype=dtype,
+                        )
+            else:
+                cache_key = StaticKVCache.calculate_cache_key(attn.layer_id, 0)
+                for cache in (self.k_caches, self.v_caches):
+                    assert (
+                        cache_key not in cache
+                    ), "Found StaticAttention modules with duplicated layer_id"
+                    cache[cache_key] = torch.zeros(
+                        batch_size,
+                        attn.n_kv_heads,
+                        cache_lens[attn.layer_id],
+                        attn.head_dim,
+                        dtype=dtype,
+                    )
+
+        self.generate_full_logits = True

     @property
     def masks(self):
@@ -352,13 +423,13 @@ def prefill(
         all_logits = None
         for i in range(0, tokens.size(1), self.input_len):
             logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
-            if self.config.generate_full_logits:
+            if self.generate_full_logits:
                 if all_logits is None:
                     all_logits = logits
                 else:
                     all_logits = torch.cat([all_logits, logits], dim=1)

-        if self.config.generate_full_logits:
+        if self.generate_full_logits:
             return all_logits[:, : tokens.size(1), :]

         return logits
@@ -637,9 +708,10 @@ def _get_lookahead_position_offsets(


 class _Rope(nn.Module):
-    def __init__(self, use_hf_rope):
+    def __init__(self, config: ModelArgs):
         super().__init__()
-        self.use_hf_rope = use_hf_rope
+        self.config = config
+        self.use_hf_rope = config.use_hf_rope

     def forward(
         self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
@@ -755,7 +827,8 @@ def __init__(
             self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])

         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
-        self.rope = _Rope(rope.params.use_hf_rope)
+        self.rope = _Rope(rope.params)
+        self.layer_id = layer_id

         if self.use_qk_norm:
             self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
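
A hedged usage sketch of the two construction paths this change enables. The enclosing IO-manager class is not named in these hunks, so StaticAttentionIOManager is assumed below, as are model_args (a ModelArgs) and model (an nn.Module whose attention layers are StaticAttention instances with distinct layer_ids); prefill's (model, tokens) signature is likewise inferred from the hunk above.

# Usage sketch; class and variable names here are assumptions, not from the patch.
import torch

# Path 1: size the KV caches from a ModelArgs config (previous behavior).
mgr = StaticAttentionIOManager(model_args, input_len=32, cache_lens=512)

# Path 2: size the KV caches directly from a model containing StaticAttention
# modules; shapes come from each module's layer_id, n_heads / n_kv_heads, and
# head_dim, and generate_full_logits is forced to True.
mgr = StaticAttentionIOManager(model, input_len=32, cache_lens=[512] * 4)

# Chunked prefill: the prompt is fed input_len tokens at a time, and with full
# logits the returned tensor covers every prompt position.
tokens = torch.randint(0, 32000, (1, 100))
logits = mgr.prefill(model, tokens)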