@@ -379,10 +379,52 @@ def build_kv_cache(self, messages) -> DynamicCache:
379379 raise ValueError (
380380 "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
381381 )
382- kv = DynamicCache ()
382+ # Create cache and perform forward pass without pre-existing cache
383383 with torch .no_grad ():
384- self .model (** inputs , use_cache = True , past_key_values = kv )
385- for i , (k , v ) in enumerate (zip (kv .key_cache , kv .value_cache , strict = False )):
386- kv .key_cache [i ] = k [:, :, :seq_len , :]
387- kv .value_cache [i ] = v [:, :, :seq_len , :]
388- return kv
384+ outputs = self .model (** inputs , use_cache = True )
385+
386+ # Get the cache from model outputs
387+ if hasattr (outputs , "past_key_values" ) and outputs .past_key_values is not None :
388+ kv = outputs .past_key_values
389+
390+ # Convert from legacy tuple format to DynamicCache if needed
391+ if isinstance (kv , tuple ):
392+ kv = DynamicCache .from_legacy_cache (kv )
393+
394+ # Handle compatibility between old and new transformers versions
395+ # In newer versions, DynamicCache uses 'layers' attribute
396+ # In older versions, it uses 'key_cache' and 'value_cache' attributes
397+ if hasattr (kv , "layers" ):
398+ # New version: trim cache using layers attribute
399+ for layer in kv .layers :
400+ if hasattr (layer , "key_cache" ) and hasattr (layer , "value_cache" ):
401+ # Trim each layer's cache to the sequence length
402+ if layer .key_cache is not None :
403+ layer .key_cache = layer .key_cache [:, :, :seq_len , :]
404+ if layer .value_cache is not None :
405+ layer .value_cache = layer .value_cache [:, :, :seq_len , :]
406+ elif hasattr (layer , "keys" ) and hasattr (layer , "values" ):
407+ # Alternative attribute names in some versions
408+ if layer .keys is not None :
409+ layer .keys = layer .keys [:, :, :seq_len , :]
410+ if layer .values is not None :
411+ layer .values = layer .values [:, :, :seq_len , :]
412+ elif hasattr (kv , "key_cache" ) and hasattr (kv , "value_cache" ):
413+ # Old version: trim cache using key_cache and value_cache attributes
414+ for i in range (len (kv .key_cache )):
415+ if kv .key_cache [i ] is not None :
416+ kv .key_cache [i ] = kv .key_cache [i ][:, :, :seq_len , :]
417+ if kv .value_cache [i ] is not None :
418+ kv .value_cache [i ] = kv .value_cache [i ][:, :, :seq_len , :]
419+ else :
420+ # Fallback: log warning but continue without trimming
421+ logger .warning (
422+ f"DynamicCache object of type { type (kv )} has unexpected structure. "
423+ f"Cache trimming skipped. Available attributes: { dir (kv )} "
424+ )
425+
426+ return kv
427+ else :
428+ raise RuntimeError (
429+ "Failed to build KV cache: no cache data available from model outputs"
430+ )
0 commit comments