14 changes: 10 additions & 4 deletions ktransformers/models/custom_cache.py
@@ -35,9 +35,10 @@ class StaticCache(transformers.StaticCache):
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:
Cache.__init__(self)
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
#Cache.__init__(self)
super().__init__(config=config, layers=config.num_hidden_layers, max_batch_size=max_batch_size, max_cache_len=max_cache_len)
self._max_batch_size = max_batch_size
self._max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
if config.architectures[0] == "DeepseekV3ForCausalLM":
self.head_dim = config.qk_rope_head_dim
@@ -109,6 +110,11 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
             self.value_cache.append(new_layer_value_cache)
             self.past_tokens.append(0)

+    def get_usable_length(self, seq_len=None, layer_idx=None):
+        if hasattr(self, "cache_position") and self.cache_position is not None:
+            return int(self.cache_position.max().item()) + 1
+        return 0
+
     def update(
         self,
         key_states: torch.Tensor,
@@ -330,4 +336,4 @@ def get_k_cache(self, layer_idx):
         return self.k_caches[layer_idx]

     def get_v_cache(self, layer_idx):
-        return self.v_caches[layer_idx]
+        return self.v_caches[layer_idx]
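
A minimal sketch of what the new get_usable_length() helper reports, for reviewers who want to see it in isolation. The import path and class name come from this diff; the checkpoint id, the CPU device, and the manual cache_position assignment are illustrative assumptions only (cache_position is normally populated by the cache's update path), and construction assumes the transformers version this PR targets.

import torch
from transformers import AutoConfig
from ktransformers.models.custom_cache import StaticCache

# Example config chosen only because the DeepseekV3ForCausalLM branch is the
# one shown in this hunk; any compatible config object would do.
model_config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)

cache = StaticCache(config=model_config, max_batch_size=1,
                    max_cache_len=4096, device=torch.device("cpu"))

print(cache.get_usable_length())   # 0 -- no cache_position recorded yet

# Once positions 0..127 have been written, cache_position holds those indices
# and the helper reports cache_position.max() + 1. Set manually here only to
# illustrate; in normal use the update path records it.
cache.cache_position = torch.arange(128)
print(cache.get_usable_length())   # 128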
3 changes: 2 additions & 1 deletion ktransformers/models/modeling_deepseek.py
@@ -62,6 +62,7 @@
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
+from transformers.generation import GenerationMixin
 from .configuration_deepseek import DeepseekV2Config
 import torch.distributed as dist
 import numpy as np
@@ -1641,7 +1642,7 @@ def _update_causal_mask(
         return causal_mask


-class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
+class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
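
Review note on the second file: the explicit GenerationMixin base tracks the transformers change in which PreTrainedModel stopped providing generate() implicitly (deprecation warnings from around v4.45, enforced in v4.50), so externally defined causal-LM classes like DeepseekV2ForCausalLM must opt in themselves. A minimal sketch of the pattern, where TinyConfig and TinyForCausalLM are hypothetical stand-ins, not part of this repository:

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin


class TinyConfig(PretrainedConfig):
    model_type = "tiny"               # hypothetical example config


class TinyForCausalLM(PreTrainedModel, GenerationMixin):
    """Hypothetical model: on newer transformers releases, GenerationMixin must be
    listed explicitly, otherwise instances no longer expose a generate() method."""
    config_class = TinyConfig
    # forward(), weights, and prepare_inputs_for_generation() are omitted in this sketch.


# With the mixin in the bases, .generate() is inherited exactly as it was before the
# deprecation, which is what this PR restores for DeepseekV2ForCausalLM.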