
Commit 305350d

Update on "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing"

Summary:

Why? We have coupled SDPA with the kv cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of kv cache updates. (This could have been done with separate custom kv cache update and custom sdpa ops; recent changes enabled this.) Because the SDPA module owns the kv cache, we get (a) a non-composable implementation and (b) difficulty reusing model definitions and components from repos like tune. The result is multiple definitions of the same model, llama, lying around in ET, TorchChat, and Tune. This diff and subsequent ones move toward decoupling the custom kv cache and custom sdpa so they are composable and module-swap friendly with tune's model definition.

How: Earlier PRs decoupled the kv cache update from sdpa. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, # heads, seq_len, head_dim] format.
3. Step 2 introduces extra transposes when KVCache and SDPA are replaced by custom modules, but a graph pass will undo those.

Test Plan: Existing tests. Make sure perf doesn't regress.

Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054)

[ghstack-poisoned]
2 parents ed78ae3 + ee290d0 commit 305350d

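To make the intended end state concrete, here is a minimal sketch of the decoupled interfaces described above. These are illustrative stand-ins, not the actual ExecuTorch `KVCache`/`SDPA` classes; the only things they share with the real code are the standardized [B, n_heads, seq_len, head_dim] tensor format and the `update(input_pos, k_val, v_val)` entry point.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalKVCache(nn.Module):
    """Sketch of a standalone KV cache: buffers are [B, n_heads, max_seq_len, head_dim]."""

    def __init__(self, max_batch_size, n_heads, max_seq_len, head_dim, dtype=torch.float32):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # k_val / v_val: [B, n_heads, seq_len, head_dim]; write along the seq dim (dim 2).
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class MinimalSDPA(nn.Module):
    """Sketch of an SDPA module that no longer owns a KV cache."""

    def forward(self, q, k, v, mask=None):
        # q, k, v: [B, n_heads, seq_len, head_dim]
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

With this split, an attention layer first calls `kv_cache.update(...)` and then the cache-free SDPA module, which is the call sequence the llama_transformer.py change below introduces.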
6 files changed: +109 −195 lines changed

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 1 deletion

@@ -312,7 +312,7 @@ def get_custom_quant_ios_dtype(
     """
     This function is specific for llama inputs and outputs
     """
-    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+    if node.op == "placeholder" and "attention_kv_cache_past_" in node.name:
         return kv_dtype

     # Tag index put node before copy node, because copy is a skipped node in qnn

examples/models/llama/llama_transformer.py

Lines changed: 1 addition & 4 deletions

@@ -286,15 +286,13 @@ def update(
 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep

@@ -373,7 +371,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,

@@ -406,7 +403,7 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            k, v = self.SDPA.kv_cache.update(input_pos, k, v)
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

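For illustration, here is how the sketch classes from the block near the top of this page would be driven with the new call sequence. Shapes are toy values, and `MinimalKVCache`/`MinimalSDPA` are the hypothetical stand-ins defined there, not ExecuTorch code.

```python
import torch

# Toy dimensions; the layout is the standardized [B, n_heads, seq_len, head_dim].
batch, n_heads, seq_len, head_dim = 1, 8, 4, 64
cache = MinimalKVCache(batch, n_heads, max_seq_len=128, head_dim=head_dim)
sdpa = MinimalSDPA()

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_heads, seq_len, head_dim)
v = torch.randn(batch, n_heads, seq_len, head_dim)
input_pos = torch.arange(seq_len)

# Mirrors the diff above: the attention layer updates its own cache,
# then hands the returned k/v tensors to the (cache-free) SDPA module.
k_full, v_full = cache.update(input_pos, k, v)
out = sdpa(q, k_full, v_full)  # [B, n_heads, seq_len, head_dim]
```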
examples/models/llama/source_transformation/attention_sink.py

Lines changed: 6 additions & 20 deletions

@@ -111,7 +111,6 @@ def __init__(
         self,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         rope: RopeWithAttentionSink,
         window_size: int,

@@ -125,7 +124,6 @@ def __init__(
             max_seq_length=window_size + sink_size,
             n_heads=n_heads,
             head_dim=head_dim,
-            transpose_cache=transpose_cache,
             enable_dynamic_shape=enable_dynamic_shape,
             dtype=dtype,
         )

@@ -161,28 +159,17 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
                 input_pos_item + self.position_shift - self.sink_size - num_to_evict
             )
             num_empty_space = self.window_size - num_to_keep
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             k_to_keep = self.k_cache.narrow(
                 dim_to_slice,
                 self.sink_size + num_to_evict,  # pyre-ignore [6]
                 num_to_keep,  # pyre-ignore [6]
             )
-            if self.transpose_cache:
-                k_to_keep = self.rope.rerotate_k(
-                    k=k_to_keep.transpose(1, 2),
-                    original_position=(  # pyre-ignore [6]
-                        self.sink_size + num_to_evict
-                    ),
-                    new_position=self.sink_size,
-                ).transpose(1, 2)
-            else:
-                k_to_keep = self.rope.rerotate_k(
-                    k=k_to_keep,
-                    original_position=(  # pyre-ignore [6]
-                        self.sink_size + num_to_evict
-                    ),
-                    new_position=self.sink_size,
-                )
+            k_to_keep = self.rope.rerotate_k(
+                k=k_to_keep.transpose(1, 2),
+                original_position=(self.sink_size + num_to_evict),  # pyre-ignore [6]
+                new_position=self.sink_size,
+            ).transpose(1, 2)
             self.k_cache = torch.cat(
                 [
                     self.k_cache.narrow(dim_to_slice, 0, self.sink_size),

@@ -278,7 +265,6 @@ def _replace_attention(
         kv_cache_with_attention_sink = KVCacheWithAttentionSink(
             n_heads=kv_cache.n_heads,
             head_dim=kv_cache.head_dim,
-            transpose_cache=kv_cache.transpose_cache,
             enable_dynamic_shape=kv_cache.enable_dynamic_shape,
             rope=rope_with_attention_sink,
             max_batch_size=kv_cache.max_batch_size,

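Since `transpose_cache` is gone, the cache layout is fixed at [B, n_heads, seq_len, head_dim] and eviction always slices along dim 2. A toy illustration of that narrow, plus the transpose dance around `rerotate_k` (which, judging from the removed branch above, appears to operate on [B, seq_len, n_heads, head_dim] tensors); all shapes and values here are made up:

```python
import torch

# Toy cache in the now-fixed [B, n_heads, seq_len, head_dim] layout.
B, n_heads, seq_len, head_dim = 1, 4, 16, 8
k_cache = torch.randn(B, n_heads, seq_len, head_dim)

sink_size, num_to_evict, num_to_keep = 4, 3, 9

# dim 2 is the sequence dimension, so evicted tokens are sliced out there.
k_to_keep = k_cache.narrow(2, sink_size + num_to_evict, num_to_keep)
print(k_to_keep.shape)  # torch.Size([1, 4, 9, 8])

# The transpose(1, 2) before/after rerotate_k converts between the cache
# layout and a [B, seq_len, n_heads, head_dim] view, then back again.
k_for_rope = k_to_keep.transpose(1, 2)  # [1, 9, 4, 8]
k_back = k_for_rope.transpose(1, 2)     # [1, 4, 9, 8], same layout as the cache
```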
examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 6 additions & 2 deletions

@@ -145,8 +145,12 @@ def update(self, input_pos, k_val, v_val):
         )

         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        if self.use_custom_update_cache_op:
+            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        else:
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val

         return k_out.transpose(1, 2), v_out.transpose(1, 2)


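The fallback branch writes new entries with plain advanced indexing along dim 2, while the custom `torch.ops.llama.update_cache` op remains the fast path. A small, self-contained sketch of what that fallback write does (toy shapes, dense float tensors rather than the quantized buffers used by the real cache):

```python
import torch

# Toy tensors: dim 2 of the cache buffer is the dimension indexed by input_pos,
# matching the fallback branch in the diff (k_out[:, :, input_pos] = k_val).
k_out = torch.zeros(1, 4, 16, 8)   # hypothetical cache buffer
k_val = torch.randn(1, 4, 3, 8)    # new entries for 3 positions
input_pos = torch.tensor([5, 6, 7])

# Fallback path: plain advanced-indexing assignment.
k_out[:, :, input_pos] = k_val

# The same write expressed with index_copy_, to make the semantics explicit.
k_ref = torch.zeros(1, 4, 16, 8)
k_ref.index_copy_(2, input_pos, k_val)
assert torch.equal(k_out, k_ref)
```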
examples/models/llama/source_transformation/sdpa.py

Lines changed: 4 additions & 12 deletions

@@ -19,11 +19,9 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim

     def forward(

@@ -65,7 +63,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.dim),
+                SDPACustom(child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)

@@ -81,13 +79,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep

@@ -135,12 +131,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class SDPAFlex(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep

@@ -177,7 +171,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPASimple(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_simple_sdpa(child)

@@ -190,7 +184,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+                SDPAFlex(child.dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)

@@ -222,13 +216,11 @@ class SDPACoreML(torch.nn.Module):

     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep

@@ -260,7 +252,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+                SDPACoreML(child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_coreml_sdpa(child)

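All of the replace_sdpa_with_* helpers in this file share the same recursive module-swap pattern, and dropping `kv_cache` from the constructors means the replacement now only needs shape-related attributes from the child. A generic sketch of that pattern (simplified; the real helpers check against the concrete `SDPA` class rather than taking it as an argument):

```python
import torch.nn as nn


def swap_modules(module: nn.Module, target_cls, make_replacement):
    """Recursively replace every child of type `target_cls` with
    `make_replacement(child)` -- the same walk used by the
    replace_sdpa_with_* helpers above."""
    for name, child in module.named_children():
        if isinstance(child, target_cls):
            setattr(module, name, make_replacement(child))
        else:
            swap_modules(child, target_cls, make_replacement)
    return module


# Hypothetical usage, mirroring replace_sdpa_with_simple_sdpa:
# swap_modules(model, SDPA, lambda child: SDPASimple(child.dim, child.head_dim, child.n_rep))
```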