[KVCache] Enable sliding window for ragged prefill (SelfAttention)
#18630
@@ -101,6 +101,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   const bool support_sliding_window_;
   /*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */
   const bool support_layer_sliding_window_;
+  /*! \brief The sliding window size for sliding window attention. */
+  int32_t sliding_window_size_;
   /*! \brief The attention kinds for each layer. */
   const std::vector<AttnKind> attn_kinds_;
@@ -314,6 +316,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
             : support_sliding_window),
         support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(),
                                                 AttnKind::kMHASliding) != attn_kinds.end()),
+        sliding_window_size_(-1),
         attn_kinds_(std::move(attn_kinds)),
         rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline
                                                                           : rope_mode),
@@ -766,6 +769,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // introduce more sink. Therefore, we update the given attn sink size.
     it->second.last_block_attn_sink_size = std::max(attn_sink_size - prefix_length, 0);
     it->second.sliding_window_size = sliding_window_size;
+    if (sliding_window_size_ == -1)
+      sliding_window_size_ = sliding_window_size;
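The two added lines lazily record the window size: sliding_window_size_ starts at -1 (see the constructor change above) and is set by the first sequence that enables sliding-window attention, so later code can consult a single cached size. Below is a minimal standalone sketch of that cache-once pattern; the MiniCache class, method name, and driver are hypothetical and only illustrate the pattern, not the actual PagedAttentionKVCacheObj code.

// Minimal standalone illustration (hypothetical, not the KV cache code):
// the first call fixes the cached window size, later calls leave it untouched.
#include <cstdint>
#include <iostream>

class MiniCache {
 public:
  MiniCache() : sliding_window_size_(-1) {}  // -1 means "not yet enabled", as in the patch

  void EnableSlidingWindow(int64_t seq_id, int32_t sliding_window_size) {
    (void)seq_id;  // per-sequence bookkeeping omitted in this sketch
    // Cache the size on first use only, mirroring the two added lines.
    if (sliding_window_size_ == -1)
      sliding_window_size_ = sliding_window_size;
    // Without an extra check, a different size requested later is silently
    // ignored here; the review suggestion below turns that into a hard error.
  }

  int32_t cached_size() const { return sliding_window_size_; }

 private:
  int32_t sliding_window_size_;
};

int main() {
  MiniCache cache;
  cache.EnableSlidingWindow(/*seq_id=*/0, /*sliding_window_size=*/4096);
  cache.EnableSlidingWindow(/*seq_id=*/1, /*sliding_window_size=*/2048);  // no effect
  std::cout << cache.cached_size() << "\n";  // prints 4096
  return 0;
}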
Suggested change (review comment on the two added lines above):
-    if (sliding_window_size_ == -1)
-      sliding_window_size_ = sliding_window_size;
+    if (sliding_window_size_ == -1) {
+      sliding_window_size_ = sliding_window_size;
+    } else {
+      ICHECK_EQ(sliding_window_size_, sliding_window_size)
+          << "Inconsistent sliding window sizes are not supported. Previously got "
+          << sliding_window_size_ << ", but now got " << sliding_window_size;
+    }
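The suggested change makes the single-size assumption explicit: since the cache stores only one sliding_window_size_, a sequence that later requests a different window size would otherwise be silently coerced to the first recorded value, whereas with the ICHECK_EQ such an inconsistent request fails with a clear error.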