1818import dataclasses
1919from typing import List , Tuple
2020
21- from ai_edge_torch import hlfb
2221from ai_edge_torch .generative .layers import model_config
2322from ai_edge_torch .generative .utilities .dynamic_update_slice import dynamic_update_slice
2423import torch
2524import torch .utils ._pytree as pytree
2625
27- BATCH_SIZE = 1
28-
2926
3027@dataclasses .dataclass
3128class KVCacheEntry :
@@ -45,9 +42,10 @@ def from_model_config(
4542 config : model_config .AttentionConfig ,
4643 dtype : torch .dtype = torch .float32 ,
4744 device : torch .device = None ,
45+ batch_size : int = 1 ,
4846 ) -> "KVCacheEntry" :
4947 """Build an instance of the class based on model config."""
50- shape = (BATCH_SIZE , kv_cache_max , config .num_query_groups , config .head_dim )
48+ shape = (batch_size , kv_cache_max , config .num_query_groups , config .head_dim )
5149 k = torch .zeros (shape , dtype = dtype , device = device )
5250 v = torch .zeros (shape , dtype = dtype , device = device )
5351 obj = cls (k_cache = k , v_cache = v )
@@ -66,6 +64,7 @@ def from_model_config(
6664 config : model_config .ModelConfig ,
6765 dtype : torch .dtype = torch .float32 ,
6866 device : torch .device = None ,
67+ batch_size : int = 1 ,
6968 ) -> "KVCache" :
7069 """Build an instance of the class based on model config.
7170
@@ -75,17 +74,21 @@ def from_model_config(
7574 Defaults to torch.float32.
7675 device (torch.device, optional): The device placement of the cache
7776 tensors. Defaults to None.
77+ batch_size (int, optional): The batch size of the cache tensors.
78+ Defaults to 1.
7879
7980 Returns:
8081 KVCache: The created cache object.
8182 """
8283 caches = [
8384 KVCacheEntry .from_model_config (
84- config .kv_cache_max if not config .block_config (idx ).kv_cache_max_len
85+ config .kv_cache_max
86+ if not config .block_config (idx ).kv_cache_max_len
8587 else config .block_config (idx ).kv_cache_max_len ,
8688 config .block_config (idx ).attn_config ,
8789 dtype ,
8890 device ,
91+ batch_size ,
8992 )
9093 for idx in range (config .num_layers )
9194 ]
0 commit comments