
Commit 883075e

hheydary authored and copybara-github committed

Internal code change.

PiperOrigin-RevId: 731863023

1 parent 6d58005 commit 883075e

2 files changed: +45 −39 lines changed


ai_edge_torch/generative/layers/experimental/attention.py

Lines changed: 0 additions & 8 deletions

@@ -52,7 +52,6 @@ def __init__(
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -119,22 +118,19 @@ class CausalSelfAttention(nn.Module):
 
   def __init__(
       self,
-      batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ) -> None:
     """Initialize an instance of CausalSelfAttention.
 
     Args:
-      batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.kv_cache = None
-    self.batch_size = batch_size
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
     ) * config.head_dim
@@ -180,10 +176,6 @@ def forward(
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )
 
     qkv = self.qkv_projection(x)
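With the batch_size constructor argument gone, CausalSelfAttention no longer carries a configured batch size to validate against; the batch dimension is simply read off the input tensor in forward(). A minimal runnable sketch of that idea (the shapes are illustrative, not taken from a real model config):

import torch

# The batch dimension is recovered from the activation itself at forward
# time, so a constructor-time batch_size (and the deleted assert) is
# redundant. Shapes below are made up for illustration.
x = torch.randn(2, 16, 128)  # (batch, sequence length, embedding dim)
B, T, E = x.size()
assert (B, T, E) == (2, 16, 128)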
ai_edge_torch/generative/layers/experimental/kv_cache.py

Lines changed: 45 additions & 31 deletions

@@ -21,23 +21,19 @@
 import dataclasses
 from typing import List, Tuple
 
-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
-from ai_edge_torch.generative.layers.experimental import types as types
-from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
+from ai_edge_torch.generative.layers.experimental import types
+from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
-BATCH_SIZE = 1
-
 
 @dataclasses.dataclass
 class KVCacheEntryBase:
   """A single cache entry that includes K and V caches.
 
   The chaches are built based on the provided config with the shape of
-  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  (batch_size, kv_cache_max, num_query_groups, head_dim).
   """
 
   k_cache: torch.Tensor
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
   @classmethod
   def _from_model_config(
       cls,
-      kv_cache_max: int,
-      config: model_config.AttentionConfig,
-      k_shape: Tuple,
-      v_shape: Tuple,
+      k_shape: Tuple[int, ...],
+      v_shape: Tuple[int, ...],
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
   ) -> "KVCacheEntryBase":
@@ -66,12 +60,11 @@ def from_model_config(
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
-    return cls._from_model_config(
-        kv_cache_max, config, shape, shape, dtype, device
-    )
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
+    return cls._from_model_config(shape, shape, dtype, device)
 
 
 @dataclasses.dataclass
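KVCacheEntryBase.from_model_config now takes the batch size as a parameter instead of the deleted module-level BATCH_SIZE constant, and the private _from_model_config receives only the precomputed shapes. A hedged sketch of the resulting cache-entry allocation; the config numbers (kv_cache_max, num_query_groups, head_dim) are illustrative assumptions, and the zero initializer is a stand-in since the allocation itself is not shown in this hunk:

import torch

# Illustrative values standing in for model_config fields.
batch_size, kv_cache_max, num_query_groups, head_dim = 1, 1024, 4, 64
shape = (batch_size, kv_cache_max, num_query_groups, head_dim)
k_cache = torch.zeros(shape, dtype=torch.float32)  # stand-in allocation
v_cache = torch.zeros(shape, dtype=torch.float32)
assert k_cache.shape == (1, 1024, 4, 64)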
@@ -93,24 +86,22 @@ def from_model_config(
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    num_kv_heads = config.num_query_groups
     k_shape = (
-        1,
-        BATCH_SIZE * num_kv_heads,
+        batch_size,
+        config.num_query_groups,
         kv_cache_max,
         config.head_dim,
-    )  # 1, bk, s, h
+    )  # b, k, s, h
     v_shape = (
-        1,
-        BATCH_SIZE * num_kv_heads,
+        batch_size,
+        config.num_query_groups,
         config.head_dim,
         kv_cache_max,
-    )  # 1, bk, h, s
-    return cls._from_model_config(
-        kv_cache_max, config, k_shape, v_shape, dtype, device
-    )
+    )  # b, k, h, s
+    return cls._from_model_config(k_shape, v_shape, dtype, device)
 
 
 @dataclasses.dataclass
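For the transposed entry, K and V now share an explicit (batch_size, num_query_groups) prefix, replacing the old fused 1 × (BATCH_SIZE * num_kv_heads) leading dimensions; the two caches differ only in the order of their last two axes. With the same illustrative numbers as above:

# b, k, s, h versus b, k, h, s -- last two dims swapped between K and V.
batch_size, kv_cache_max, num_query_groups, head_dim = 1, 1024, 4, 64
k_shape = (batch_size, num_query_groups, kv_cache_max, head_dim)
v_shape = (batch_size, num_query_groups, head_dim, kv_cache_max)
assert k_shape[:2] == v_shape[:2]
assert k_shape[2:] == tuple(reversed(v_shape[2:]))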
@@ -126,13 +117,15 @@ def _from_model_config(
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     caches = [
         kv_entry_cls.from_model_config(
            config.kv_cache_max,
            config.block_config(idx).attn_config,
            dtype,
            device,
+           batch_size,
        )
        for idx in range(config.num_layers)
    ]
@@ -145,6 +138,7 @@ def from_model_config(
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     """Build an instance of the class based on model config.
 
@@ -154,12 +148,19 @@ def from_model_config(
         Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.
 
     Returns:
       KVCacheBase: The created cache object.
     """
+    assert batch_size == 1, "Batch size must be 1 for KV Cache."
     return cls._from_model_config(
-        KVCacheEntryBase, config=config, dtype=dtype, device=device
+        KVCacheEntryBase,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
   def flatten(self) -> List[torch.Tensor]:
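Even though batch_size is now threaded through the whole stack, the public KVCacheBase entry point still pins it to 1 with the assert above, so the parameter is forward-looking plumbing rather than working multi-batch support. A hypothetical usage sketch (`config` is a stand-in, not a real model_config.ModelConfig instance):

# Hypothetical usage against the API shown in this diff.
cache = KVCacheBase.from_model_config(config, batch_size=1)  # fine
cache = KVCacheBase.from_model_config(config, batch_size=2)  # raises
# AssertionError: Batch size must be 1 for KV Cache.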
@@ -177,9 +178,14 @@ def from_model_config(
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryBTNH, config=config, dtype=dtype, device=device
+        KVCacheEntryBTNH,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
 
@@ -192,9 +198,14 @@ def from_model_config(
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryTransposed, config=config, dtype=dtype, device=device
+        KVCacheEntryTransposed,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
 
@@ -258,7 +269,6 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    use_dus: bool = True,
 ) -> KVCacheEntryBase:
   """Out of place update of Cache buffer.
 
@@ -309,6 +319,10 @@ def _update_kv_impl(
   positions = input_pos.clone()
   k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
   v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-  k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
-  v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
+  k = dus_utils.dynamic_update_slice(
+      cache.k_cache, k_slice, [x for x in k_slice_indices]
+  )
+  v = dus_utils.dynamic_update_slice(
+      cache.v_cache, v_slice, [x for x in v_slice_indices]
+  )
   return KVCacheEntryTransposed(k, v)
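The cache update stays out of place: dus_utils.dynamic_update_slice returns fresh tensors rather than mutating the cache, in the spirit of XLA's DynamicUpdateSlice. A minimal pure-PyTorch analogue of that semantics, for orientation only (dynamic_update_slice_ref is a hypothetical helper, not the library function):

import torch

def dynamic_update_slice_ref(target, update, start_indices):
  # Out of place: clone first, so the caller's cache tensor is untouched.
  result = target.clone()
  idx = tuple(
      slice(int(i), int(i) + size)
      for i, size in zip(start_indices, update.shape)
  )
  result[idx] = update
  return result

# Write a one-position KV slice into a (b, k, h, s) = (1, 4, 2, 8) cache
# at sequence position 3.
cache = torch.zeros(1, 4, 2, 8)
k_slice = torch.ones(1, 4, 2, 1)
starts = [torch.tensor(0), torch.tensor(0), torch.tensor(0), torch.tensor(3)]
updated = dynamic_update_slice_ref(cache, k_slice, starts)
assert updated[0, 0, 0, 3] == 1
assert cache.sum() == 0  # original cache unchanged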
