
Commit 51f7614

haozha111 authored and copybara-github committed
Introduce sdpa_with_kv_update function. It performs the appropriate cache update and pre/post-SDPA logic based on the KV cache layout.
PiperOrigin-RevId: 745609775
1 parent 44ed8bb commit 51f7614
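
For orientation, here is a minimal usage sketch (not part of this commit) of the new entry point; the wrapper function below is illustrative, while the real call site is the attention.py change shown in the diff further down.

# Sketch only: an attention layer delegates the cache update and the SDPA call
# to the new helper instead of hard-coding a KV layout. Argument shapes and
# types follow the diffs in this commit.
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.layers import sdpa_with_kv_update
import ai_edge_torch.generative.layers.model_config as cfg


def attend(q, k, v, kv_entry: kv_utils.KVCacheEntry, input_pos, mask,
           config: cfg.AttentionConfig):
  # The helper inspects kv_entry.kv_layout and runs the matching pre/post-SDPA
  # logic (transposed GPU layout vs. the default path), returning the attention
  # output together with the updated cache entry.
  sdpa_out, kv_entry = sdpa_with_kv_update.sdpa_with_kv_update(
      q, k, v, kv_entry, input_pos, mask, config
  )
  return sdpa_out, kv_entry

The net effect is that the layout-specific transpose and cache-update code moves out of the experimental attention layer into one place that dispatches on the cache layout.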

File tree

5 files changed: +130 -36 lines changed


ai_edge_torch/generative/layers/experimental/attention.py

Lines changed: 3 additions & 33 deletions
@@ -24,8 +24,7 @@
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -147,7 +146,6 @@ def __init__(
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = sdpa.scaled_dot_product_attention

   def forward(
       self,
@@ -221,36 +219,8 @@ def forward(
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

-    # Transpose k/v to specific layout for GPU implementation.
-    b, _, n, h = q.shape
-    g = n // self.config.num_query_groups
-    # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
-    q = q.permute(0, 2, 1, 3).reshape(
-        1, b * self.config.num_query_groups, g * T, h
-    )
-
-    k = k.permute(0, 2, 1, 3).reshape(
-        1, -1, T, self.config.head_dim
-    )  # 1, bk, s, h
-    v = v.permute(0, 2, 3, 1).reshape(
-        1, -1, self.config.head_dim, T
-    )  # 1, bk, h, s
-
-    if kv_cache is not None:
-      kv_cache = kv_utils_experimental.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        kv_cache,
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
-    )  # 1, bk, gt, h
-    sdpa_out = (
-        sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config
     )

     # Compute the output projection.

ai_edge_torch/generative/layers/experimental/kv_cache.py

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,8 @@ def update(
   assert (
       cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
   ), "KV entry must have transposed layout."
-  return _update_kv_impl_transposed(cache, input_pos, k_slice, v_slice)
+  update_kv_cache = _update_kv_impl_transposed
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)


 def _get_slice_indices(

ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py

Lines changed: 0 additions & 1 deletion
@@ -82,7 +82,6 @@ def _sdpa(k_type, v_type, *args, **kwargs):
   padded_logits = logits + mask
   padded_logits = padded_logits.reshape(1, bk, gt, s)
   probs = F.softmax(padded_logits, dim=-1).type_as(key)
-
   encoded = bmm_lib.bmm_4d(probs, value)

   return encoded  # 1, bk, gt, h

ai_edge_torch/generative/layers/sdpa_with_kv_update.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+from typing import Tuple
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
+from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
+from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers.experimental import types
+import ai_edge_torch.generative.layers.model_config as cfg
+from multipledispatch import dispatch
+import torch
+
+
+def sdpa_with_kv_update(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  return sdpa_with_kv_update_impl(
+      kv.kv_layout[0](),  # key layout
+      kv.kv_layout[1](),  # value layout
+      query=query,
+      key=key,
+      value=value,
+      kv=kv,
+      input_pos=input_pos,
+      mask=mask,
+      config=config,
+  )
+
+
+@dispatch(types.BNTH, types.BNHT)
+def sdpa_with_kv_update_impl(
+    k_type, v_type, *args, **kwargs
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  query = kwargs["query"]
+  key = kwargs["key"]
+  value = kwargs["value"]
+  kv = kwargs["kv"]
+  input_pos = kwargs["input_pos"]
+  mask = kwargs["mask"]
+  config = kwargs["config"]
+
+  # Transpose k/v to specific layout for GPU implementation.
+  b, seq_len, n, h = query.shape
+  g = n // config.num_query_groups
+  # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
+  query = query.permute(0, 2, 1, 3).reshape(
+      1, b * config.num_query_groups, g * seq_len, h
+  )
+
+  key = key.permute(0, 2, 1, 3).reshape(
+      1, -1, seq_len, config.head_dim
+  )  # 1, bk, s, h
+  value = value.permute(0, 2, 3, 1).reshape(
+      1, -1, config.head_dim, seq_len
+  )  # 1, bk, h, s
+
+  if kv is not None:
+    kv = kv_utils_experimental.update(kv, input_pos, key, value)
+    key, value = kv.k_cache, kv.v_cache
+
+  sdpa_out = sdpa.scaled_dot_product_attention(
+      kv,
+      query,
+      key,
+      value,
+      config.head_dim,
+      mask=mask,
+      softcap=config.logit_softcap,
+  )  # 1, bk, gt, h
+  sdpa_out = (
+      sdpa_out.reshape(b, -1, seq_len, h)
+      .permute(0, 2, 1, 3)
+      .reshape(b, seq_len, -1)
+  )
+  return sdpa_out, kv
+
+
+@dispatch(object, object)
+def sdpa_with_kv_update_impl(
+    k_type, v_type, *args, **kwargs
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  query = kwargs["query"]
+  key = kwargs["key"]
+  value = kwargs["value"]
+  kv = kwargs["kv"]
+  input_pos = kwargs["input_pos"]
+  mask = kwargs["mask"]
+  config = kwargs["config"]
+
+  b, seq_len, _, _ = query.shape
+  if kv is not None:
+    kv = kv_utils.update(kv, input_pos, key, value)
+    key, value = kv.k_cache, kv.v_cache
+
+  sdpa_out = sdpa_default.scaled_dot_product_attention(
+      query,
+      key,
+      value,
+      config.head_dim,
+      mask=mask,
+      softcap=config.logit_softcap,
+  )
+  sdpa_out = sdpa_out.reshape(b, seq_len, -1)
+  return sdpa_out, kv
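
The layout-based routing in the new module relies on multipledispatch: the public sdpa_with_kv_update instantiates the marker types stored in kv.kv_layout and lets the dispatcher pick an implementation. Below is a self-contained sketch of that pattern; the BNTH/BNHT classes are local stand-ins for illustration, not the real ai_edge_torch.generative.layers.experimental.types module.

# Illustrative only: shows how @dispatch selects an implementation from the
# KV layout marker types, mirroring the structure of the module above.
from multipledispatch import dispatch


class BNTH:
  """Stand-in marker for a (batch, heads, tokens, head_dim) layout."""


class BNHT:
  """Stand-in marker for a transposed (batch, heads, head_dim, tokens) layout."""


@dispatch(BNTH, BNHT)
def route(k_layout, v_layout):
  # Most specific registration wins: transposed-cache (experimental GPU) path.
  return "transposed kv cache path"


@dispatch(object, object)
def route(k_layout, v_layout):
  # Fallback for any other layout pair: default SDPA path.
  return "default path"


print(route(BNTH(), BNHT()))      # -> transposed kv cache path
print(route(object(), object()))  # -> default path

Any layout pair without a dedicated registration falls through to the (object, object) overload, which is how the default scaled_dot_product_attention path is selected.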

ai_edge_torch/generative/utilities/experimental/verifier.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from typing import Any, List, Optional

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch

 ExportConfig = export_config.ExportConfig
