Skip to content

Conversation

@ggerganov
Copy link
Member

fix #10380

Disallow context shifts for models that do not support it (such as DeepSeek V2). Add

bool llama_kv_cache_can_shift(struct llama_context * ctx);

@ggerganov ggerganov merged commit 8e752a7 into master Nov 19, 2024
1 check passed
@ggerganov ggerganov deleted the gg/llama-can-shift branch November 19, 2024 11:29
Comment on lines +20465 to +20467
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this return false for recurrent models as well? Not sure what's the logic there, but llama_kv_cache_update_internal silently ignores models with LLAMA_ROPE_TYPE_NONE.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's likely needed to return false for recurrent models.

The reason to do nothing in llama_kv_cache_update_internal when the rope type is none is because when we apply shifts to the KV cache using functions like llama_kv_cache_seq_add(), we do 2 things:

  • Update the positions of the KV cells - i.e. just modify the meta data in llama_kv_cell
  • Re-rope the data in the KV cells

The later step is necessary only if the data is roped. For ALiBi models for example, we should not apply this second step, but in theory we still support "shifting" the KV cache for those models, since the positional information is in the KQ mask.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Dec 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Deepseek2 does not support K-shift Denial-of-Service vulnerability

4 participants