
Commit 44ed8bb

haozha111 authored and copybara-github committed
Unify the KV cache classes between the standard layers and the experimental layers. Implement a single KVCache and KVCacheEntry class that can handle different layouts (regular or transposed).
PiperOrigin-RevId: 745342772
1 parent 67244ba commit 44ed8bb
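
For reference, a minimal sketch of the unified API after this change, based on the usages in the diffs below. The build_caches helper and its ModelConfig argument are illustrative only and are not part of the commit.

# Illustrative sketch; the helper name and default-layout call are assumptions.
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg


def build_caches(config: cfg.ModelConfig):
  # Regular layout: the default used by the standard layers.
  regular = kv_utils.KVCache.from_model_config(config)
  # Transposed layout: selected with the kv_layout argument instead of the
  # removed KVCacheTransposed class.
  transposed = kv_utils.KVCache.from_model_config(
      config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
  )
  return regular, transposed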

File tree: 14 files changed (+169, −368 lines)


ai_edge_torch/generative/examples/experimental/gemma/convert_gemma2_gpu_to_tflite.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@

 from absl import app
 from ai_edge_torch.generative.examples.experimental.gemma import gemma2_gpu
-from ai_edge_torch.generative.layers.experimental import kv_cache
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 import torch
@@ -50,7 +50,7 @@ def _create_export_config(
   )
   decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
   export_config.decode_mask = decode_mask
-  export_config.kvcache_cls = kv_cache.KVCacheTransposed
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
   return export_config
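
In the converter scripts, the migration is a one-line swap: the layout is now a constant on the export config rather than a cache class. A hedged before/after sketch; the attribute names come from the hunk above, while the bare ExportConfig() construction is an assumption.

from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import export_config as export_config_lib

config = export_config_lib.ExportConfig()  # assumed default construction
# Old (removed): the layout was implied by picking an experimental cache class.
#   config.kvcache_cls = kv_cache.KVCacheTransposed
# New: the layout is an explicit field on the export config.
config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED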

ai_edge_torch/generative/examples/experimental/gemma/gemma2_gpu.py

Lines changed: 9 additions & 9 deletions

@@ -25,9 +25,9 @@
 from typing import List, Optional, Tuple

 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -75,8 +75,8 @@ def forward(
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.KVCacheEntryBase = None,
-  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntryBase]]:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.

     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -87,7 +87,7 @@ def forward(
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntryBase): the optional kv cache entry.
+      kv_cache (KVCacheEntry): the optional kv cache entry.

     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -151,10 +151,10 @@ def forward(
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCacheBase,
+      kv_cache: kv_utils.KVCache,
       mask: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
@@ -185,9 +185,9 @@ def _forward_with_embeds(
       rope: Tuple[torch.Tensor, torch.Tensor],
       mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCacheBase,
+      kv_cache: kv_utils.KVCache,
       export_config: Optional[export_cfg.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
@@ -204,7 +204,7 @@ def _forward_with_embeds(
       x, kv_entry = block(x, rope, mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCacheBase(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
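
The per-layer loop in the model forward now re-wraps the updated entries in the unified KVCache class (previously KVCacheBase in the experimental layers). A simplified sketch of that pattern; the per-layer mask selection and export handling from the real code are omitted, and run_blocks is an illustrative name.

from ai_edge_torch.generative.layers import kv_cache as kv_utils


def run_blocks(x, rope, mask, input_pos, kv_cache, transformer_blocks):
  """Simplified per-layer loop; argument types follow the diff above."""
  updated_kv_entries = []
  for i, block in enumerate(transformer_blocks):
    kv_entry = kv_cache.caches[i] if kv_cache else None
    x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
    if kv_entry:
      updated_kv_entries.append(kv_entry)
  # One KVCacheEntry per transformer block, re-wrapped in the unified container.
  return x, kv_utils.KVCache(tuple(updated_kv_entries))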

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@

 from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
-from ai_edge_torch.generative.layers.experimental import kv_cache
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 import torch
@@ -58,7 +58,7 @@ def _create_export_config(
   )
   decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
   export_config.decode_mask = decode_mask
-  export_config.kvcache_cls = kv_cache.KVCacheTransposed
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
   return export_config

ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 8 additions & 9 deletions

@@ -18,9 +18,9 @@
 from typing import List, Optional, Tuple

 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -81,8 +81,8 @@ def forward(
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.KVCacheEntryBase = None,
-  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntryBase]]:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma3Block.

     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -241,13 +241,12 @@ def forward(
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCacheBase,
+      kv_cache: kv_utils.KVCache,
       input_embeds: Optional[torch.Tensor] = None,
       mask: Optional[torch.Tensor] = None,
       image_indices: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
-
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     pixel_mask = None
     if input_embeds is None:
       # token embeddings of shape (b, t, n_embd)
@@ -287,10 +286,10 @@ def _forward_with_embeds(
       rope: List[Tuple[torch.Tensor, torch.Tensor]],
       mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCacheBase,
+      kv_cache: kv_utils.KVCache,
       pixel_mask: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
@@ -326,7 +325,7 @@ def _forward_with_embeds(
       x, kv_entry = block(x, rope[i], mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCacheBase(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
     if export_config is not None:
       if (
          torch.numel(input_pos) > 1

ai_edge_torch/generative/examples/gemma3/verify_util.py

Lines changed: 4 additions & 2 deletions

@@ -20,8 +20,8 @@
 from typing import List, Optional, Tuple

 from ai_edge_torch.generative.examples.gemma3 import gemma3
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.experimental import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -94,7 +94,9 @@ class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):

   def _init_kv_cache(self):
     """Returns an initialized KV cache."""
-    return kv_utils.KVCacheTransposed.from_model_config(self.model.model.config)
+    return kv_utils.KVCache.from_model_config(
+        self.model.model.config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
+    )

   def forward(
       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None

ai_edge_torch/generative/layers/experimental/attention.py

Lines changed: 9 additions & 9 deletions

@@ -22,8 +22,9 @@
 from typing import Optional, Tuple, Union

 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
 from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -69,17 +70,17 @@ def forward(
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.KVCacheEntryBase = None,
+      kv_cache: kv_utils.KVCacheEntry = None,
       lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

     Args:
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntryBase): the optional kv cache entry.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
       lora (LoRAEntry): the optional lora entry.

     Returns:
@@ -154,9 +155,9 @@ def forward(
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-      kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
       lora: Optional[lora_utils.LoRAEntry] = None,
-  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntryBase]]:
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

     MQA, GQA and MHA.
@@ -166,8 +167,7 @@ def forward(
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntryBase): the KV cache entry corresponding to this
-        module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
       lora (LoRAEntry): the optional lora entry.

     Returns:
@@ -237,7 +237,7 @@ def forward(
     ) # 1, bk, h, s

     if kv_cache is not None:
-      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
+      kv_cache = kv_utils_experimental.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache

     sdpa_out = self.sdpa_func(