
Commit 18cddd5

concat KV cache
1 parent 3f75e3d commit 18cddd5

File tree

1 file changed: +42 -38 lines changed

examples/models/llama/llama_transformer.py

Lines changed: 42 additions & 38 deletions
@@ -474,59 +474,63 @@ def forward(
 
         if self.decode_kv_cache_as_io:
             assert self.use_kv_cache
-            # mask = self.mask[None, None, input_pos]
+            mask = attn_mask
             if self.use_additive_kv_cache_update:
-                assert cache_pos_mask is not None
                 assert seqlen == 1
-                k_update = cache_pos_mask * k
-                v_update = cache_pos_mask * v
-                k = k_cache + k_update
-                v = v_cache + v_update
-                assert k.shape == k_cache.shape
-                assert v.shape == v_cache.shape
-
-                # # Attempt 1 to use torch.cat:
-                # # This fails to lower to ET during to_executorch due to a dynamo error related to the
-                # # delegate call. We can talk to compiler about this, but the bigger issue is although
-                # # the CoreML mlpackage lowers, it fails at runtime on CPU/ANE with "input data broken / unsupported
-                # # model (model code -7)". It does run on GPU.
-                # # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask
-
-                # buffer = 2 # needed to make dynamo happy
-                # input_pos_item = input_pos[0].item()
-                # torch._check_is_size(input_pos_item)
-                # torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
-                # mask = torch.narrow(mask, dim=3, start=0, length=input_pos_item + seqlen)
-
-                # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
-                # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
-
-                # # Attempt 2 to use torch.cat
-                # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
-                # # I'm not confident this variant will work in CoreML if we can export it, though.
+                # assert cache_pos_mask is not None
+                # k_update = cache_pos_mask * k
+                # v_update = cache_pos_mask * v
+                # print("k_update", k_update.shape)
+                # print("k_cache", k_cache.shape)
+                # k = k_cache + k_update
+                # v = v_cache + v_update
+                # assert k.shape == k_cache.shape
+                # assert v.shape == v_cache.shape
+
+                # Attempt 1 to use torch.cat:
+                # This fails to lower to ET during to_executorch due to a dynamo error related to the
+                # delegate call. We can talk to compiler about this, but the bigger issue is although
+                # the CoreML mlpackage lowers, it fails at runtime on CPU/ANE with "input data broken / unsupported
+                # model (model code -7)". It does run on GPU.
+                # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask
+
+                buffer = 2 # needed to make dynamo happy
+                torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
+                mask = torch.narrow(mask, dim=1, start=0, length=input_pos + seqlen)
+
+                k = torch.cat(
+                    [torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], axis=2
+                )
+                v = torch.cat(
+                    [torch.narrow(v_cache, dim=2, start=0, length=input_pos), v], axis=2
+                )
+
+                # # # Attempt 2 to use torch.cat
+                # # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
+                # # # I'm not confident this variant will work in CoreML if we can export it, though.
                 # buffer = 2
-                # input_pos_item = input_pos[0].item()
-                # torch._check_is_size(input_pos_item)
-                # torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
+                # # input_pos_item = input_pos[0].item()
+                # # torch._check_is_size(input_pos_item)
+                # torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
 
-                # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
-                # k = k.expand(k_cache.size())
-                # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
+                # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], axis=2)
+
+                # # torch.Size([1, 12, 1, 64]) torch.Size([1, 12, 1024, 64]) torch.Size([1, 12, 1024, 64])
+                # k = k.expand(k_cache.size()) # torch.Size([1, 12, 1024, 64])
+                # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos), v], axis=2)
                 # v = v.expand(v_cache.size())
             else:
                 k = torch.ops.aten.index_put(k_cache, [None, None, input_pos, None], k)
                 v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)
         else:
             assert not self.use_kv_cache
-            # assert hasattr(self, "mask")
-
-            # mask = self.mask[:seqlen, :seqlen]
+            mask = attn_mask
 
         # grouped multiquery attention: expand out keys and values
         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
+        output = torch.ops.coreml.sdpa(q, k, v, mask)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
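For context, here is a minimal standalone sketch of the two cache-update styles this diff toggles between: the masked additive update that is now commented out, and the narrow-plus-concat update it enables. The helper names (additive_kv, concat_kv), the one-hot reading of cache_pos_mask, and the sizes (12 heads, 64-dim, 1024 max length, taken from the shape comments in the diff) are illustrative assumptions, not the module's actual configuration.

import torch

n_heads, head_dim, max_seq_len = 12, 64, 1024  # assumed, from the shape comments
buffer = 2  # slack below max_seq_len, mirroring the dynamo workaround in the diff

k_cache = torch.zeros(1, n_heads, max_seq_len, head_dim)
v_cache = torch.zeros(1, n_heads, max_seq_len, head_dim)

def additive_kv(k_cache, v_cache, k, v, cache_pos_mask):
    # Old path: assuming cache_pos_mask is one-hot over the sequence dim,
    # the masked add writes k/v into the current slot while keeping the
    # full static cache shape.
    return k_cache + cache_pos_mask * k, v_cache + cache_pos_mask * v

def concat_kv(k_cache, v_cache, k, v, input_pos, seqlen):
    # New path: bound the dynamic position so export-time shape checks pass,
    # then append the new step's K/V to the filled prefix of each cache.
    torch._check(input_pos + seqlen <= max_seq_len - buffer)
    k = torch.cat([torch.narrow(k_cache, 2, 0, input_pos), k], dim=2)
    v = torch.cat([torch.narrow(v_cache, 2, 0, input_pos), v], dim=2)
    return k, v

# One decode step (seqlen == 1) at position 7.
k_new = torch.randn(1, n_heads, 1, head_dim)
v_new = torch.randn(1, n_heads, 1, head_dim)

one_hot = torch.zeros(1, 1, max_seq_len, 1)
one_hot[:, :, 7, :] = 1.0
k_a, v_a = additive_kv(k_cache, v_cache, k_new, v_new, one_hot)
print(k_a.shape)  # torch.Size([1, 12, 1024, 64]) -- static cache shape

k_c, v_c = concat_kv(k_cache, v_cache, k_new, v_new, input_pos=7, seqlen=1)
print(k_c.shape)  # torch.Size([1, 12, 8, 64]) -- data-dependent length

The trade-off the in-code comments describe falls out of these shapes: the additive update keeps k and v at the full static cache size, while the concat path produces a sequence length of input_pos + seqlen that depends on runtime data, which is what the earlier torch.cat attempts tripped over when lowering to ExecuTorch and CoreML.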