[DRAFT] Try make quantize kv cache work #6926
Changes from all commits
First file:

@@ -142,13 +142,13 @@ def __init__(self, **kwargs):
         self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
-        if self.use_kv_cache:
-            print("Setting up KV cache on the model...")
-            self.model_.setup_caches(
-                batch_size=1,
-                dtype=self.dtype,
-                decoder_max_seq_len=self.max_seq_len,
-            )
+        # if self.use_kv_cache:
+        #     print("Setting up KV cache on the model...")
+        #     self.model_.setup_caches(
+        #         batch_size=1,
+        #         dtype=self.dtype,
+        #         decoder_max_seq_len=self.max_seq_len,
+        #     )
 
     def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
Comment on lines +145 to +151 (Contributor, Author): We need to do this because the source transform happens after the model is set up, and we need to call the newly swapped-in ET attention's setup_cache function. So we move the setup_caches call to after the source transform.
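A minimal sketch of the intended ordering, assuming a hypothetical build helper (the function and transform names are illustrative; the setup_caches arguments mirror the ones commented out above):

```python
# Illustrative only: apply the source transforms that swap in the ExecuTorch
# attention / KV-cache modules first, then set up the caches on the
# transformed model so the swapped-in attention's setup_cache is the one
# that runs.
def build_for_export(model, transforms, use_kv_cache, dtype, max_seq_len):
    # 1) Source transforms (e.g. replacing torchtune attention/SDPA with the
    #    ExecuTorch variants); each transform takes and returns the model.
    for transform in transforms:
        model = transform(model)

    # 2) Only now allocate the KV caches, after the swap has happened.
    if use_kv_cache:
        model.setup_caches(
            batch_size=1,
            dtype=dtype,
            decoder_max_seq_len=max_seq_len,
        )
    return model
```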
Second file:

@@ -9,6 +9,13 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
+from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
+from executorch.examples.models.llama.source_transformation.sdpa import (
+    SDPACustom,
+    SDPASimple,
+)
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
@@ -145,16 +152,27 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
-        self._sdpa = SDPA(
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            head_dim=self.head_dim,
-            attn_dropout=self.attn_dropout if self.training else 0.0,
-            is_causal=self.is_causal,
-            attention_fn=self._attention_call,
-            kv_cache=self.kv_cache,
-        )
+        # self._sdpa = SDPA(
+        #     num_kv_heads=self.num_kv_heads,
+        #     num_heads=self.num_heads,
+        #     head_dim=self.head_dim,
+        #     attn_dropout=self.attn_dropout if self.training else 0.0,
+        #     is_causal=self.is_causal,
+        #     attention_fn=self._attention_call,
+        #     kv_cache=self.kv_cache,
+        # )
+
+        self._sdpa = SDPACustom(
+            kv_cache=self.kv_cache,
+        )
+
+        # self._sdpa = SDPASimple(
+        #     kv_cache=self.kv_cache,
+        #     dim=self.embed_dim,
+        #     head_dim=self.head_dim,
+        #     n_rep=self.num_heads // self.num_kv_heads
+        # )
 
         # this flag indicates whether to update the kv-cache during forward
         # passes. when disabled, we can have the cache setup but still
         # perform normal forward passes
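As a purely illustrative aside, the modules being swapped in follow the pattern of an SDPA wrapper that owns a reference to the KV cache, which is why the setup_cache hunk below can build a new cache and simply re-point self._sdpa.kv_cache. The sketch is not the real SDPACustom or SDPASimple, and the cache update API shown is an assumption.

```python
# Hypothetical minimal SDPA wrapper holding a KV-cache reference (not the
# actual SDPACustom implementation).
from typing import Optional

import torch
from torch import nn


class NaiveSDPAWithCache(nn.Module):
    def __init__(self, kv_cache: Optional[nn.Module] = None):
        super().__init__()
        self.kv_cache = kv_cache  # re-assigned when the cache is swapped out

    def forward(self, input_pos, q, k, v):
        if self.kv_cache is not None:
            # Assumed cache API: store the new k/v at input_pos and return
            # the full cached keys/values seen so far.
            k, v = self.kv_cache.update(input_pos, k, v)
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)
```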
@@ -177,13 +195,20 @@ def setup_cache(
                 "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
             )
         else:
-            self.kv_cache = InferenceKVCache(
-                batch_size=batch_size,
-                max_seq_len=max_seq_len,
-                num_kv_heads=self.num_kv_heads,
-                head_dim=self.head_dim,
-                dtype=dtype,
-                transpose_cache=False,
-            )
+            # self.kv_cache = InferenceKVCache(
+            #     batch_size=batch_size,
+            #     max_seq_len=max_seq_len,
+            #     num_kv_heads=self.num_kv_heads,
+            #     head_dim=self.head_dim,
+            #     dtype=dtype,
+            #     transpose_cache=False,
+            # )
+            self.kv_cache = QuantizedKVCache(
+                max_batch_size=batch_size,
+                max_seq_length=max_seq_len,
+                n_heads=self.num_kv_heads,
+                head_dim=self.head_dim,
+                # dtype needs to be float32 atm,
+            )
             self._sdpa.kv_cache = self.kv_cache
             self.cache_enabled = True

Comment on the setup_cache hunk (Contributor): can you try adding
Comment: setup_cache moved here.
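For context on what a quantized KV cache buys, here is a conceptual sketch of per-token affine int8 quantization of new key/value entries, dequantized on read before attention. This illustrates the general technique only; it is not the actual QuantizedKVCache implementation, and the shapes and scheme used below are assumptions.

```python
import torch


def quantize_per_token(x: torch.Tensor, eps: float = 1e-6):
    # Asymmetric int8 quantization with one scale/zero-point per token (per
    # head); x is assumed float32 shaped [bsz, seq, n_kv_heads, head_dim].
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=eps) / 255.0
    zero_point = (-128.0 - x_min / scale).round()
    q = (x / scale + zero_point).round().clamp(-128, 127).to(torch.int8)
    return q, scale, zero_point


def dequantize_per_token(q, scale, zero_point):
    # Recover approximate float32 values before they are consumed by SDPA.
    return (q.to(torch.float32) - zero_point) * scale


# Example: quantize one new key entry and read it back.
k_new = torch.randn(1, 1, 8, 64)
k_q, k_scale, k_zp = quantize_per_token(k_new)
k_restored = dequantize_per_token(k_q, k_scale, k_zp)
```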