diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index 350af73e..e9838a7f 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -35,9 +35,10 @@ class StaticCache(transformers.StaticCache):
     """
 
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
-        Cache.__init__(self)
-        self.max_batch_size = max_batch_size
-        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        #Cache.__init__(self)
+        super().__init__(config=config, layers=config.num_hidden_layers, max_batch_size=max_batch_size, max_cache_len=max_cache_len)
+        self._max_batch_size = max_batch_size
+        self._max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         if config.architectures[0] == "DeepseekV3ForCausalLM":
             self.head_dim = config.qk_rope_head_dim
@@ -109,6 +110,11 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
             self.value_cache.append(new_layer_value_cache)
             self.past_tokens.append(0)
 
+    def get_usable_length(self, seq_len=None, layer_idx=None):
+        if hasattr(self, "cache_position") and self.cache_position is not None:
+            return int(self.cache_position.max().item()) + 1
+        return 0
+
     def update(
         self,
         key_states: torch.Tensor,
@@ -330,4 +336,4 @@ def get_k_cache(self, layer_idx):
         return self.k_caches[layer_idx]
 
     def get_v_cache(self, layer_idx):
-        return self.v_caches[layer_idx]
\ No newline at end of file
+        return self.v_caches[layer_idx]
diff --git a/ktransformers/models/modeling_deepseek.py b/ktransformers/models/modeling_deepseek.py
index f6845eca..f03a9628 100644
--- a/ktransformers/models/modeling_deepseek.py
+++ b/ktransformers/models/modeling_deepseek.py
@@ -62,6 +62,7 @@
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
+from transformers.generation import GenerationMixin
 from .configuration_deepseek import DeepseekV2Config
 import torch.distributed as dist
 import numpy as np
@@ -1641,7 +1642,7 @@ def _update_causal_mask(
         return causal_mask
 
 
-class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
+class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
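
Note on the StaticCache change: the patch appears to target a newer transformers cache API, running the parent constructor via super().__init__(...) and answering length queries from cache_position instead of a manual counter, which is what the new get_usable_length shim does. Below is a minimal, self-contained sketch of that computation only; ToyCache and the example tensor are hypothetical stand-ins, not code from this patch or from transformers.

import torch

class ToyCache:
    # Hypothetical stand-in for the patched StaticCache, showing only the
    # cache_position bookkeeping that get_usable_length relies on.
    def __init__(self):
        self.cache_position = None  # set by update(); None before anything is cached

    def get_usable_length(self, seq_len=None, layer_idx=None):
        # Same rule as the patch: usable length is one past the highest
        # position written so far, or 0 when nothing has been cached yet.
        if self.cache_position is not None:
            return int(self.cache_position.max().item()) + 1
        return 0

cache = ToyCache()
print(cache.get_usable_length())        # 0 -> nothing cached yet
cache.cache_position = torch.arange(7)  # pretend a 7-token prompt was written
print(cache.get_usable_length())        # 7

The GenerationMixin change follows the same compatibility theme: recent transformers releases no longer provide generate() through PreTrainedModel alone, so a custom causal-LM class such as DeepseekV2ForCausalLM has to inherit GenerationMixin explicitly for model.generate(...) to keep working.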