
Commit 9fd3d12

[ExecuTorch][BE] Split kv cache and SDPA for better code sharing
Summary:

Why? We have coupled SDPA with the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (This could instead have been done with a separate custom KV cache update op and a custom SDPA op; recent changes enabled that.) Because the SDPA module owns the KV cache, we get a) a non-composable implementation and b) model definitions and components that are harder to reuse from repos like tune. The result is that we have multiple definitions of the same model, llama, lying around in ET, TorchChat, and Tune. This diff and subsequent ones move toward decoupling custom KV cache and custom SDPA so they become composable, making the model definition more module-swap friendly with tune's.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, # heads, seq_len, head_dim] layout.
3. Step 2 introduces extra transposes when KVCache and SDPA are replaced by custom modules, but a graph pass will undo them.

Test Plan: Existing tests. Make sure perf doesn't regress.

ghstack-source-id: abaea2c
Pull Request resolved: #7413
1 parent f4e77c7 commit 9fd3d12
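
To make the intended end state concrete, here is a minimal sketch of the decoupled, composable interface the summary describes: a cache module and an SDPA module that both consume q, k, v in [B, n_heads, seq_len, head_dim] layout and are wired together explicitly by the attention block. The class names (SketchKVCache, SketchSDPA) and shapes are illustrative only, not the actual ExecuTorch modules.

# Minimal sketch of the decoupled interface described above; names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchKVCache(nn.Module):
    """Holds k/v in [B, n_heads, max_seq_len, head_dim]; update() writes along dim 2."""

    def __init__(self, max_batch_size, n_heads, max_seq_len, head_dim, dtype=torch.float32):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # k_val / v_val: [B, n_heads, seq_len, head_dim]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SketchSDPA(nn.Module):
    """SDPA that no longer owns a KV cache; it attends over whatever it is handed."""

    def forward(self, q, k, v, attn_mask):
        # q/k/v: [B, n_heads, seq_len, head_dim]
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)


# Composition inside an attention block: update the cache, then attend.
cache = SketchKVCache(max_batch_size=1, n_heads=8, max_seq_len=128, head_dim=64)
sdpa = SketchSDPA()
q = torch.randn(1, 8, 1, 64)  # one new token, already transposed to [B, H, S, D]
k = torch.randn(1, 8, 1, 64)
v = torch.randn(1, 8, 1, 64)
input_pos = torch.tensor([0])
k_all, v_all = cache.update(input_pos, k, v)
attn_mask = torch.zeros(1, 1, 1, 128)  # additive mask over the full cache length
out = sdpa(q, k_all, v_all, attn_mask)  # [1, 8, 1, 64]

Either module can now be replaced independently (for example by a quantized cache or a custom SDPA op) without the other knowing about it.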

File tree

17 files changed: +511 −382 lines changed

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,8 @@ fi

 if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
   QUANTIZE_KV_CACHE=ON
+  # quantize_kv cache transform uses custom kv cache update op
+  CUSTOM=ON
 else
   QUANTIZE_KV_CACHE=OFF
 fi

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def get_custom_quant_ios_dtype(
     """
     This function is specific for llama inputs and outputs
     """
-    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+    if node.op == "placeholder" and "attention_kv_cache_past_" in node.name:
         return kv_dtype

     # Tag index put node before copy node, because copy is a skipped node in qnn

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -665,6 +665,8 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()

+    builder_exported.run_canonical_optimizations()
+
     if args.export_only:
         exit()

examples/models/llama/llama_transformer.py

Lines changed: 15 additions & 31 deletions
@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val

         return k_out, v_out


 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,13 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +319,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]

+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
@@ -383,11 +368,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             args.max_seq_len,
             self.n_kv_heads,
             self.head_dim,
-            not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
             args.enable_dynamic_shape,
         )
         self.SDPA = SDPA(
-            kv_cache=self.kv_cache,
             dim=self.n_local_heads * self.head_dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
@@ -414,15 +397,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
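
With SDPA no longer holding a reference to the KV cache, either module can be swapped independently by an ordinary nn.Module traversal. The helper below is a hypothetical sketch (swap_kv_cache and make_custom_cache are made-up names, not ExecuTorch APIs) of how a source transformation could replace every KVCache child, e.g. with a quantized or custom-op-backed cache, without touching SDPA.

# Hypothetical sketch: replace every KVCache child with a custom cache module.
# swap_kv_cache / make_custom_cache are illustrative names, not ExecuTorch APIs.
import torch.nn as nn


def swap_kv_cache(module: nn.Module, make_custom_cache) -> nn.Module:
    for name, child in module.named_children():
        if type(child).__name__ == "KVCache":
            # Build the replacement from the existing cache (shapes, dtype, etc.).
            setattr(module, name, make_custom_cache(child))
        else:
            swap_kv_cache(child, make_custom_cache)
    return module

Because SDPA now just consumes whatever k, v the attention block passes in, a swap like this no longer needs the second child_module.SDPA.kv_cache assignment that the attention_sink change below removes.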

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 6 additions & 21 deletions
@@ -111,7 +111,6 @@ def __init__(
         self,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         rope: RopeWithAttentionSink,
         window_size: int,
@@ -125,7 +124,6 @@ def __init__(
             max_seq_length=window_size + sink_size,
             n_heads=n_heads,
             head_dim=head_dim,
-            transpose_cache=transpose_cache,
             enable_dynamic_shape=enable_dynamic_shape,
             dtype=dtype,
         )
@@ -161,28 +159,17 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
                 input_pos_item + self.position_shift - self.sink_size - num_to_evict
             )
             num_empty_space = self.window_size - num_to_keep
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             k_to_keep = self.k_cache.narrow(
                 dim_to_slice,
                 self.sink_size + num_to_evict,  # pyre-ignore [6]
                 num_to_keep,  # pyre-ignore [6]
             )
-            if self.transpose_cache:
-                k_to_keep = self.rope.rerotate_k(
-                    k=k_to_keep.transpose(1, 2),
-                    original_position=(  # pyre-ignore [6]
-                        self.sink_size + num_to_evict
-                    ),
-                    new_position=self.sink_size,
-                ).transpose(1, 2)
-            else:
-                k_to_keep = self.rope.rerotate_k(
-                    k=k_to_keep,
-                    original_position=(  # pyre-ignore [6]
-                        self.sink_size + num_to_evict
-                    ),
-                    new_position=self.sink_size,
-                )
+            k_to_keep = self.rope.rerotate_k(
+                k=k_to_keep.transpose(1, 2),
+                original_position=(self.sink_size + num_to_evict),  # pyre-ignore [6]
+                new_position=self.sink_size,
+            ).transpose(1, 2)
             self.k_cache = torch.cat(
                 [
                     self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
@@ -278,7 +265,6 @@ def _replace_attention(
             kv_cache_with_attention_sink = KVCacheWithAttentionSink(
                 n_heads=kv_cache.n_heads,
                 head_dim=kv_cache.head_dim,
-                transpose_cache=kv_cache.transpose_cache,
                 enable_dynamic_shape=kv_cache.enable_dynamic_shape,
                 rope=rope_with_attention_sink,
                 max_batch_size=kv_cache.max_batch_size,
@@ -288,7 +274,6 @@ def _replace_attention(
                 dtype=kv_cache.k_cache.dtype,
             )
             child_module.kv_cache = kv_cache_with_attention_sink
-            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
             child_module.forward = types.MethodType(  # pyre-ignore
                 attention_sink_forward, child_module
            )
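
For intuition on the eviction path above, here is a toy, illustrative-only walk-through of the slice-and-concatenate layout math on a [B, n_heads, seq_len, head_dim] cache. The sizes are made up, and the RoPE re-rotation of the kept keys (rope.rerotate_k in the diff) is omitted.

# Toy sketch of eviction on a [B, n_heads, seq_len, head_dim] cache:
# keep the sink prefix, shift the kept window toward the sink, zero-fill the freed space.
import torch

k_cache = torch.arange(16.0).reshape(1, 1, 8, 2)  # B=1, n_heads=1, seq_len=8, head_dim=2
sink_size, num_to_evict, num_to_keep = 2, 3, 3    # made-up sizes for illustration

dim_to_slice = 2  # sequence dim in the now always [B, H, S, D] layout
k_sink = k_cache.narrow(dim_to_slice, 0, sink_size)                              # [1, 1, 2, 2]
k_to_keep = k_cache.narrow(dim_to_slice, sink_size + num_to_evict, num_to_keep)  # [1, 1, 3, 2]
num_empty_space = k_cache.size(dim_to_slice) - sink_size - num_to_keep
empty = torch.zeros(1, 1, num_empty_space, 2)
k_cache = torch.cat([k_sink, k_to_keep, empty], dim=dim_to_slice)
print(k_cache.shape)  # torch.Size([1, 1, 8, 2])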
