2121import  dataclasses 
2222from  typing  import  List , Tuple 
2323
24- from  ai_edge_torch  import  hlfb 
2524from  ai_edge_torch .generative .layers  import  model_config 
26- from  ai_edge_torch .generative .layers .experimental  import  types   as   types 
27- from  ai_edge_torch .generative .utilities . dynamic_update_slice  import  dynamic_update_slice 
25+ from  ai_edge_torch .generative .layers .experimental  import  types 
26+ from  ai_edge_torch .generative .utilities  import  dynamic_update_slice   as   dus_utils 
2827import  torch 
29- import  torch .nn  as  nn 
3028import  torch .utils ._pytree  as  pytree 
3129
32- BATCH_SIZE  =  1 
33- 
3430
3531@dataclasses .dataclass  
3632class  KVCacheEntryBase :
3733  """A single cache entry that includes K and V caches. 
3834
3935  The caches are built based on the provided config with the shape of 
40-   (batch_size=1 , kv_cache_max, num_query_groups, head_dim). 
36+   (batch_size, kv_cache_max, num_query_groups, head_dim). 
4137  """ 
4238
4339  k_cache : torch .Tensor 
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
4642  @classmethod  
4743  def  _from_model_config (
4844      cls ,
49-       kv_cache_max : int ,
50-       config : model_config .AttentionConfig ,
51-       k_shape : Tuple ,
52-       v_shape : Tuple ,
45+       k_shape : Tuple [int , ...],
46+       v_shape : Tuple [int , ...],
5347      dtype : torch .dtype  =  torch .float32 ,
5448      device : torch .device  =  None ,
5549  ) ->  "KVCacheEntryBase" :
@@ -66,12 +60,11 @@ def from_model_config(
6660      config : model_config .AttentionConfig ,
6761      dtype : torch .dtype  =  torch .float32 ,
6862      device : torch .device  =  None ,
63+       batch_size : int  =  1 ,
6964  ) ->  "KVCacheEntryBase" :
7065    """Build an instance of the class based on model config.""" 
71-     shape  =  (BATCH_SIZE , kv_cache_max , config .num_query_groups , config .head_dim )
72-     return  cls ._from_model_config (
73-         kv_cache_max , config , shape , shape , dtype , device 
74-     )
66+     shape  =  (batch_size , kv_cache_max , config .num_query_groups , config .head_dim )
67+     return  cls ._from_model_config (shape , shape , dtype , device )
7568
7669
7770@dataclasses .dataclass  
@@ -93,24 +86,22 @@ def from_model_config(
9386      config : model_config .AttentionConfig ,
9487      dtype : torch .dtype  =  torch .float32 ,
9588      device : torch .device  =  None ,
89+       batch_size : int  =  1 ,
9690  ) ->  "KVCacheEntryBase" :
9791    """Build an instance of the class based on model config.""" 
98-     num_kv_heads  =  config .num_query_groups 
9992    k_shape  =  (
100-         1 ,
101-         BATCH_SIZE   *   num_kv_heads ,
93+         batch_size ,
94+         config . num_query_groups ,
10295        kv_cache_max ,
10396        config .head_dim ,
104-     )  # 1, bk , s, h 
97+     )  # b, k , s, h 
10598    v_shape  =  (
106-         1 ,
107-         BATCH_SIZE   *   num_kv_heads ,
99+         batch_size ,
100+         config . num_query_groups ,
108101        config .head_dim ,
109102        kv_cache_max ,
110-     )  # 1, bk, h, s 
111-     return  cls ._from_model_config (
112-         kv_cache_max , config , k_shape , v_shape , dtype , device 
113-     )
103+     )  # b, k, h, s 
104+     return  cls ._from_model_config (k_shape , v_shape , dtype , device )
114105
115106
116107@dataclasses .dataclass  
@@ -126,13 +117,15 @@ def _from_model_config(
126117      config : model_config .ModelConfig ,
127118      dtype : torch .dtype  =  torch .float32 ,
128119      device : torch .device  =  None ,
120+       batch_size : int  =  1 ,
129121  ) ->  "KVCacheBase" :
130122    caches  =  [
131123        kv_entry_cls .from_model_config (
132124            config .kv_cache_max ,
133125            config .block_config (idx ).attn_config ,
134126            dtype ,
135127            device ,
128+             batch_size ,
136129        )
137130        for  idx  in  range (config .num_layers )
138131    ]
@@ -145,6 +138,7 @@ def from_model_config(
145138      config : model_config .ModelConfig ,
146139      dtype : torch .dtype  =  torch .float32 ,
147140      device : torch .device  =  None ,
141+       batch_size : int  =  1 ,
148142  ) ->  "KVCacheBase" :
149143    """Build an instance of the class based on model config. 
150144
@@ -154,12 +148,19 @@ def from_model_config(
154148          Defaults to torch.float32. 
155149        device (torch.device, optional): The device placement of the cache 
156150          tensors. Defaults to None. 
151+         batch_size (int, optional): The batch size of the cache tensors. 
152+           Defaults to 1. 
157153
158154    Returns: 
159155        KVCacheBase: The created cache object. 
160156    """ 
157+     assert  batch_size  ==  1 , "Batch size must be 1 for KV Cache." 
161158    return  cls ._from_model_config (
162-         KVCacheEntryBase , config = config , dtype = dtype , device = device 
159+         KVCacheEntryBase ,
160+         config = config ,
161+         dtype = dtype ,
162+         device = device ,
163+         batch_size = batch_size ,
163164    )
164165
165166  def  flatten (self ) ->  List [torch .Tensor ]:
@@ -177,9 +178,14 @@ def from_model_config(
177178      config : model_config .ModelConfig ,
178179      dtype : torch .dtype  =  torch .float32 ,
179180      device : torch .device  =  None ,
181+       batch_size : int  =  1 ,
180182  ) ->  "KVCacheBTNH" :
181183    return  cls ._from_model_config (
182-         KVCacheEntryBTNH , config = config , dtype = dtype , device = device 
184+         KVCacheEntryBTNH ,
185+         config = config ,
186+         dtype = dtype ,
187+         device = device ,
188+         batch_size = batch_size ,
183189    )
184190
185191
@@ -192,9 +198,14 @@ def from_model_config(
192198      config : model_config .ModelConfig ,
193199      dtype : torch .dtype  =  torch .float32 ,
194200      device : torch .device  =  None ,
201+       batch_size : int  =  1 ,
195202  ) ->  "KVCacheBTNH" :
196203    return  cls ._from_model_config (
197-         KVCacheEntryTransposed , config = config , dtype = dtype , device = device 
204+         KVCacheEntryTransposed ,
205+         config = config ,
206+         dtype = dtype ,
207+         device = device ,
208+         batch_size = batch_size ,
198209    )
199210
200211
@@ -258,7 +269,6 @@ def update(
258269    input_pos : torch .Tensor ,
259270    k_slice : torch .Tensor ,
260271    v_slice : torch .Tensor ,
261-     use_dus : bool  =  True ,
262272) ->  KVCacheEntryBase :
263273  """Out of place update of Cache buffer. 
264274
@@ -309,6 +319,10 @@ def _update_kv_impl(
309319  positions  =  input_pos .clone ()
310320  k_slice_indices  =  _get_slice_indices (positions , cache_dim , k_ts_idx )
311321  v_slice_indices  =  _get_slice_indices (positions , cache_dim , v_ts_idx )
312-   k  =  dynamic_update_slice (cache .k_cache , k_slice , [x  for  x  in  k_slice_indices ])
313-   v  =  dynamic_update_slice (cache .v_cache , v_slice , [x  for  x  in  v_slice_indices ])
322+   k  =  dus_utils .dynamic_update_slice (
323+       cache .k_cache , k_slice , [x  for  x  in  k_slice_indices ]
324+   )
325+   v  =  dus_utils .dynamic_update_slice (
326+       cache .v_cache , v_slice , [x  for  x  in  v_slice_indices ]
327+   )
314328  return  KVCacheEntryTransposed (k , v )
0 commit comments