
Commit f46489c

protobird-git authored and copybara-github committed

Call sdpa_with_kv_update from non-experimental attention

- Replace multidispatch with an explicit if clause for strongly typed function calls.
- Add unit tests for the transposed KV cache.

PiperOrigin-RevId: 750707606
1 parent 9d83b50 commit f46489c

File tree

3 files changed: +84 -67 lines changed

ai_edge_torch/generative/layers/attention.py
ai_edge_torch/generative/layers/sdpa_with_kv_update.py
ai_edge_torch/generative/test/test_model_conversion.py
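
The commit message describes replacing layout-based multiple dispatch with an explicit if clause. Below is a minimal, runnable sketch of that pattern using generic placeholder names; the repo's real constants and helpers appear in the diffs that follow, and nothing in this sketch is the library's actual code.

from typing import Tuple

# Illustrative layout tags only; ai_edge_torch defines its own KV_LAYOUT_*
# constants in ai_edge_torch.generative.layers.kv_cache (see the diffs below).
KV_LAYOUT_DEFAULT = ("BTNH", "BTNH")
KV_LAYOUT_TRANSPOSED = ("BNTH", "BNHT")


def _attend_transposed(query, kv_entry) -> Tuple[str, object]:
  # Stand-in for the GPU-friendly transposed-layout path.
  return "transposed", kv_entry


def _attend_default(query, kv_entry, enable_hlfb: bool) -> Tuple[str, object]:
  # Stand-in for the default path; enable_hlfb selects the HLFB SDPA variant.
  return ("hlfb" if enable_hlfb else "default"), kv_entry


def sdpa_with_kv_update(query, kv_entry, kv_layout, enable_hlfb: bool) -> Tuple[str, object]:
  # An explicit branch replaces the former @dispatch registrations, so the
  # call site is an ordinary, strongly typed function call.
  if kv_entry is not None and kv_layout == KV_LAYOUT_TRANSPOSED:
    return _attend_transposed(query, kv_entry)
  return _attend_default(query, kv_entry, enable_hlfb)


print(sdpa_with_kv_update("q", {"k": [], "v": []}, KV_LAYOUT_TRANSPOSED, enable_hlfb=True)[0])
# -> transposed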

ai_edge_torch/generative/layers/attention.py

Lines changed: 4 additions & 18 deletions

@@ -21,6 +21,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -142,11 +143,6 @@ def __init__(
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = (
-        sdpa.scaled_dot_product_attention_with_hlfb
-        if enable_hlfb
-        else sdpa.scaled_dot_product_attention
-    )

   def forward(
       self,
@@ -174,7 +170,7 @@ def forward(
       KV Cach Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
-    B, T, E = x.size()
+    B, T, _ = x.size()
     qkv = self.qkv_projection(x)

     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -218,19 +214,9 @@ def forward(
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

-    if kv_cache is not None:
-      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
     )
-    sdpa_out = sdpa_out.reshape(B, T, -1)

     # Compute the output projection.
     y = self.output_projection(sdpa_out)
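
For reference, a hedged sketch of the call that now replaces the removed inline cache update and SDPA-function selection. The wrapper's argument order comes from the diff above; the standalone function, variable names, and tensor shapes around it are assumptions for illustration only.

from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.layers import sdpa_with_kv_update


def attend(q, k, v, kv_entry: kv_utils.KVCacheEntry, input_pos, mask, config,
           enable_hlfb: bool):
  # One helper now updates the KV cache entry and runs SDPA; the HLFB variant
  # is chosen from enable_hlfb inside the helper instead of in __init__.
  sdpa_out, kv_entry = sdpa_with_kv_update.sdpa_with_kv_update(
      q, k, v, kv_entry, input_pos, mask, config, enable_hlfb
  )
  return sdpa_out, kv_entry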

ai_edge_torch/generative/layers/sdpa_with_kv_update.py

Lines changed: 36 additions & 40 deletions

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Common utility functions for data loading etc.
-from dataclasses import dataclass
+
+"""Common utility functions for data loading etc."""
+
 from typing import Tuple
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
 from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
-from ai_edge_torch.generative.utilities import types
-from multipledispatch import dispatch
 import torch


@@ -33,32 +33,27 @@ def sdpa_with_kv_update(
     input_pos: torch.Tensor,
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  return sdpa_with_kv_update_impl(
-      kv.kv_layout[0](),  # key layout
-      kv.kv_layout[1](),  # value layout
-      query=query,
-      key=key,
-      value=value,
-      kv=kv,
-      input_pos=input_pos,
-      mask=mask,
-      config=config,
+  """Wrapper function for scaled dot product attention with KV cache update."""
+  if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
+    return _sdpa_with_kv_update_transposed(
+        query, key, value, kv, input_pos, mask, config
+    )
+  return _sdpa_with_kv_update_default(
+      query, key, value, kv, input_pos, mask, config, enable_hlfb
   )


-@dispatch(types.BNTH, types.BNHT)
-def sdpa_with_kv_update_impl(
-    k_type, v_type, *args, **kwargs
+def _sdpa_with_kv_update_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   # Transpose k/v to specific layout for GPU implementation.
   b, seq_len, n, h = query.shape
   g = n // config.num_query_groups
@@ -74,9 +69,8 @@ def sdpa_with_kv_update_impl(
       1, -1, config.head_dim, seq_len
   )  # 1, bk, h, s

-  if kv is not None:
-    kv = kv_utils_experimental.update(kv, input_pos, key, value)
-    key, value = kv.k_cache, kv.v_cache
+  kv = kv_utils_experimental.update(kv, input_pos, key, value)
+  key, value = kv.k_cache, kv.v_cache

   sdpa_out = sdpa.scaled_dot_product_attention(
       kv,
@@ -95,24 +89,26 @@ def sdpa_with_kv_update_impl(
   return sdpa_out, kv


-@dispatch(object, object)
-def sdpa_with_kv_update_impl(
-    k_type, v_type, *args, **kwargs
+def _sdpa_with_kv_update_default(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   b, seq_len, _, _ = query.shape
   if kv is not None:
     kv = kv_utils.update(kv, input_pos, key, value)
     key, value = kv.k_cache, kv.v_cache

-  sdpa_out = sdpa_default.scaled_dot_product_attention(
+  if enable_hlfb:
+    sdpa_func = sdpa_default.scaled_dot_product_attention_with_hlfb
+  else:
+    sdpa_func = sdpa_default.scaled_dot_product_attention
+  sdpa_out = sdpa_func(
       query,
       key,
       value,
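
A note on the routing: in the removed implementation, kv.kv_layout was a pair of layout type classes that multipledispatch matched against (types.BNTH keys and types.BNHT values selected the transposed path, with an (object, object) fallback). The rewrite compares kv.kv_layout against a constant instead. A small hedged sketch of that check as a standalone predicate follows; the helper itself is hypothetical, while the constant and attribute names come from the diff above.

from ai_edge_torch.generative.layers import kv_cache as kv_utils


def uses_transposed_path(kv_entry: kv_utils.KVCacheEntry) -> bool:
  # Mirrors the if clause added to sdpa_with_kv_update above (hypothetical helper).
  return (
      kv_entry is not None
      and kv_entry.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
  )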

ai_edge_torch/generative/test/test_model_conversion.py

Lines changed: 44 additions & 9 deletions

@@ -41,15 +41,15 @@ def setUp(self):
         )
     )

-  def _get_params(self, enable_hlfb: bool):
+  def _get_params(self, enable_hlfb: bool, kv_layout: kv_cache.KVLayout):
     """Returns a model, edge model and the kwargs to use for testing."""
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = enable_hlfb
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
     kwargs = {
         "tokens": tokens,
         "input_pos": input_pos,
@@ -65,8 +65,12 @@ def _get_params(self, enable_hlfb: bool):
     )
     return pytorch_model, edge_model, kwargs

-  def _test_model_with_kv_cache(self, enable_hlfb: bool):
-    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
+  def _test_model_with_kv_cache(
+      self,
+      enable_hlfb: bool = False,
+      kv_layout: kv_cache.KVLayout = kv_cache.KV_LAYOUT_DEFAULT,
+  ):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb, kv_layout)

     self.assertTrue(
         test_utils.compare_tflite_torch(
@@ -95,13 +99,22 @@ def test_toy_model_with_kv_cache(self):
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)

+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_toy_model_with_kv_cache_transposed(self):
+    self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
-    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    _, edge_model, _ = self._get_params(
+        enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
+    )
     interpreter_ = interpreter.InterpreterWithCustomOps(
         custom_op_registerers=["GenAIOpsRegisterer"],
         model_content=edge_model.tflite_model(),
@@ -112,7 +125,14 @@ def test_toy_model_has_dus_op(self):
     op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)

-  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
+  def _test_multisig_model(
+      self,
+      config,
+      pytorch_model,
+      atol,
+      rtol,
+      kv_layout=kv_cache.KV_LAYOUT_DEFAULT,
+  ):
     # prefill
     seq_len = 10
     prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
@@ -124,7 +144,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int)

-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)

     edge_model = (
         ai_edge_torch.signature(
@@ -160,7 +180,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
             kv,
             signature_name="prefill",
             atol=atol,
-            rtol=atol,
+            rtol=rtol,
         )
     )

@@ -173,7 +193,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
             kv,
             signature_name="decode",
             atol=atol,
-            rtol=atol,
+            rtol=rtol,
         )
     )

@@ -186,6 +206,21 @@ def test_tiny_llama_multisig(self):
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)

+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def test_tiny_llama_multisig_kv_layout_transposed(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(
+        config,
+        pytorch_model,
+        atol=1e-5,
+        rtol=1e-5,
+        kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED,
+    )
+

 if __name__ == "__main__":
   googletest.main()
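
The new tests cover both cache layouts through the kv_layout parameter. Below is a hedged sketch of the cache construction they rely on; the kv_cache import path matches attention.py's imports, while the toy-model import path is an assumption, since the test file's own imports are not shown in this diff.

# Assumed import path for the toy example model used by the tests.
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.layers import kv_cache

config = toy_model_with_kv_cache.get_model_config()
kv_default = kv_cache.KVCache.from_model_config(
    config, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
)
kv_transposed = kv_cache.KVCache.from_model_config(
    config, kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED
)

Note that the transposed-layout tests are skipped when ai_edge_torch.config.in_oss is true, since they rely on custom ops that are not supported in the OSS build.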
