@@ -325,8 +325,9 @@ def __init__(self, sliding_window, *args, **kwargs):
             sliding_window (`int`):
                 Effective window size: number of tokens that are kept on each update call.
         """
-        kwargs.pop("max_cache_len", None)
-        super().__init__(*args, max_cache_len=sliding_window, **kwargs)
+        max_cache_len = kwargs.pop("max_cache_len", None)
+        max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
+        super().__init__(*args, max_cache_len=max_cache_len, **kwargs)
 
     def update(
         self,
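The effect of the new clamping can be checked in isolation. The sketch below is a hypothetical stand-in for the `__init__` logic above, not library code; `resolve_max_cache_len` is a made-up name.

```python
from typing import Optional


def resolve_max_cache_len(sliding_window: int, max_cache_len: Optional[int]) -> int:
    # Mirrors the new __init__ logic: the cache never allocates more slots
    # than the sliding window, but a caller may still request fewer.
    return min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window


assert resolve_max_cache_len(4096, None) == 4096   # default: the window itself
assert resolve_max_cache_len(4096, 8192) == 4096   # oversized request is clamped
assert resolve_max_cache_len(4096, 1024) == 1024   # smaller request passes through
```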
@@ -1277,9 +1278,7 @@ def max_batch_size(self) -> int:
     def max_cache_len(self) -> int:
         """Return the maximum cache length of the cache"""
         values = [layer.max_cache_len for layer in self.layers]
-        if len(set(values)) > 1:
-            raise ValueError(f"Max cache length is not consistent across layers: {values}")
-        return values[0]
+        return max(values)
 
     @property
     def is_compileable(self) -> bool:
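The consistency check is dropped, presumably because a cache may now legitimately mix layers with different capacities: full-attention layers keep `max_cache_len` slots while sliding-window layers are clamped to their window (per the first hunk), so the property reports the largest per-layer value instead of raising. A toy illustration, with `_Layer` as a hypothetical stand-in:

```python
class _Layer:
    # Hypothetical stand-in: only the attribute read by the property matters.
    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len


# Full-attention layers keep 8192 slots, sliding-window layers only 4096.
layers = [_Layer(8192), _Layer(4096), _Layer(8192)]
values = [layer.max_cache_len for layer in layers]
assert max(values) == 8192  # the old code would have raised ValueError here
```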
@@ -1655,7 +1654,7 @@ class QuantoQuantizedCache(QuantizedCache):
     """
 
     def __init__(self, **kwargs) -> None:
-        Cache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
+        DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)
 
 
 class HQQQuantizedCache(QuantizedCache):
@@ -1697,7 +1696,7 @@ class HQQQuantizedCache(QuantizedCache):
 
     def __init__(self, backend="HQQ", **kwargs) -> None:
         assert backend == "HQQ"
-        Cache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
+        DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)
 
 
 class EncoderDecoderCache(Cache):
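Both quantized caches now chain through `DynamicCache.__init__` rather than `Cache.__init__`, so any state that `DynamicCache` sets up is no longer skipped. The toy reconstruction below assumes `QuantizedCache` inherits from `DynamicCache`; the class bodies and the `"quanto-processor"` placeholder are stand-ins, not the real implementations:

```python
class Cache:
    def __init__(self, cache_processor=None, **kwargs):
        self.cache_processor = cache_processor


class DynamicCache(Cache):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layers = []  # dynamic state that Cache.__init__ alone never creates


class QuantizedCache(DynamicCache):
    pass


class QuantoQuantizedCache(QuantizedCache):
    def __init__(self, **kwargs):
        # Chaining through DynamicCache.__init__ (not Cache.__init__)
        # ensures the dynamic state such as `layers` actually gets set up.
        DynamicCache.__init__(self, cache_processor="quanto-processor", **kwargs)


cache = QuantoQuantizedCache()
assert cache.layers == []  # present only because DynamicCache.__init__ ran
```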
@@ -1951,10 +1950,6 @@ def parse_layer_args_from_model_config(
         )
         # Adjust max_cache_len for sliding window layers (they can't be larger than sliding window)
         max_cache_len = max_cache_len or config.max_position_embeddings
-        if getattr(config, "sliding_window", None) is not None:
-            sliding_window_len = min(config.sliding_window, max_cache_len)
-        else:
-            sliding_window_len = None
         # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
         head_dim = (
             config.head_dim
@@ -1981,7 +1976,7 @@ def parse_layer_args_from_model_config(
             "layer_device_map": layer_device_map,
             "head_dim": head_dim,
             "num_heads": num_heads,
-            "sliding_window": sliding_window_len,
+            "sliding_window": getattr(config, "sliding_window", None),
         }
         return {k: v for k, v in layer_args.items() if v is not None}
 
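With the clamping moved into the sliding-window layer's `__init__` (first hunk), the parser simply forwards the raw `config.sliding_window` and still drops `None` values. A condensed sketch of the new behavior, using a made-up `DummyConfig`:

```python
class DummyConfig:
    # Made-up config exposing only the attributes the parser reads here.
    max_position_embeddings = 8192
    sliding_window = 4096


config = DummyConfig()
max_cache_len = None or config.max_position_embeddings

layer_args = {
    "max_cache_len": max_cache_len,
    # Raw window size; clamping against max_cache_len now happens inside
    # the sliding-window layer's __init__, not in the parser.
    "sliding_window": getattr(config, "sliding_window", None),
}
layer_args = {k: v for k, v in layer_args.items() if v is not None}
assert layer_args == {"max_cache_len": 8192, "sliding_window": 4096}
```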