
Commit 5eb4c6f

Update on "Changes to split kv cache and sdpa"

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
1 parent b981f06 commit 5eb4c6f

File tree

1 file changed: +10 -22 lines

  • examples/models/llama/source_transformation/sdpa.py

examples/models/llama/source_transformation/sdpa.py

Lines changed: 10 additions & 22 deletions
@@ -82,13 +82,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -103,11 +101,6 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
@@ -141,12 +134,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class SDPAFlex(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep
 
@@ -160,9 +151,10 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
+        """
+        q: (bs, n_heads, seqlen, head_dim)
+        k, v: (bs, n_local_heads, seqlen, head_dim)
+        """
         k = repeat_kv(k, self.n_rep)
         v = repeat_kv(v, self.n_rep)
         attn_mask = mask[input_pos]
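For readers skimming the hunk above: repeat_kv (whose signature appears in the previous hunk header) expands grouped KV heads so that k and v match the query head count before attention. A minimal sketch of the usual implementation, included for context only; the authoritative definition lives in sdpa.py itself:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (bs, n_kv_heads, seqlen, head_dim)
    # to (bs, n_kv_heads * n_rep, seqlen, head_dim).
    bs, n_kv_heads, seqlen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        bs, n_kv_heads, n_rep, seqlen, head_dim
    )
    return expanded.reshape(bs, n_kv_heads * n_rep, seqlen, head_dim)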
@@ -182,7 +174,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPASimple(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_simple_sdpa(child)
@@ -195,7 +187,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+                SDPAFlex(child.dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)
@@ -227,13 +219,11 @@ class SDPACoreML(torch.nn.Module):
 
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -248,11 +238,6 @@ def forward(
         seqlen,
         mask,
     ):
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         attn_mask = mask[None, None, input_pos]
 
         if self.n_rep > 1:
@@ -270,7 +255,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPACoreML(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_coreml_sdpa(child)
@@ -357,6 +342,9 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # can we combine this with KVCacheCoreML?
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val)
         v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val)
 
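Taken together, the hunks above split cache maintenance from attention: SDPASimple, SDPAFlex, and SDPACoreML no longer hold a kv_cache or transpose q/k/v, and the transpose moves into the cache's update method. A minimal sketch of the resulting caller-side flow; SimpleKVCache here is a hypothetical stand-in, not the class defined in this file:

import torch

class SimpleKVCache(torch.nn.Module):
    # Hypothetical cache illustrating the new division of labor:
    # update() owns the transpose, SDPA only computes attention.
    def __init__(self, n_heads: int, max_seq: int, head_dim: int):
        super().__init__()
        self.k = torch.zeros(1, n_heads, max_seq, head_dim)
        self.v = torch.zeros(1, n_heads, max_seq, head_dim)

    def update(self, input_pos, k_val, v_val):
        # k_val/v_val arrive as (bs, seqlen, n_heads, head_dim);
        # the transpose that the SDPA modules used to do now happens here.
        self.k[:, :, input_pos] = k_val.transpose(1, 2)
        self.v[:, :, input_pos] = v_val.transpose(1, 2)
        return self.k, self.v

cache = SimpleKVCache(n_heads=4, max_seq=16, head_dim=8)
input_pos = torch.arange(3)
k_new = torch.randn(1, 3, 4, 8)  # (bs, seqlen, n_heads, head_dim)
v_new = torch.randn(1, 3, 4, 8)
k, v = cache.update(input_pos, k_new, v_new)
# k, v: (bs, n_heads, max_seq, head_dim) -- the attention layer now passes
# these (plus an already-transposed q) into the SDPA module, instead of the
# SDPA module calling kv_cache.update itself.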