
Commit 2597994

Concat cache without narrow
1 parent 92b84b1 commit 2597994

1 file changed (+8, -14)


examples/models/llama/llama_transformer.py

Lines changed: 8 additions & 14 deletions
@@ -474,7 +474,6 @@ def forward(

         if self.decode_kv_cache_as_io:
             assert self.use_kv_cache
-            mask = attn_mask
             if self.use_additive_kv_cache_update:
                 assert seqlen == 1
                 # assert cache_pos_mask is not None
@@ -494,16 +493,12 @@ def forward(
                 # model (model code -7)". It does run on GPU.
                 # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask

-                buffer = 2 # needed to make dynamo happy
-                torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
-                mask = torch.narrow(mask, dim=1, start=0, length=input_pos + seqlen)
+                # buffer = 2 # needed to make dynamo happy
+                # torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
+                # mask = torch.narrow(mask, dim=1, start=0, length=input_pos + seqlen)

-                k = torch.cat(
-                    [torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], axis=2
-                )
-                v = torch.cat(
-                    [torch.narrow(v_cache, dim=2, start=0, length=input_pos), v], axis=2
-                )
+                k = torch.cat([k_cache, k], axis=2)
+                v = torch.cat([v_cache, v], axis=2)

                 # # # Attempt 2 to use torch.cat
                 # # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
@@ -524,13 +519,12 @@ def forward(
                 v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)
         else:
             assert not self.use_kv_cache
-            mask = attn_mask

         # grouped multiquery attention: expand out keys and values
         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
-        output = torch.ops.coreml.sdpa(q, k, v, mask)
+        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)

         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

@@ -680,8 +674,8 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_seq_len + 1,
+                self.max_seq_len + 1,
                 dtype=torch.float16,
                 device="cpu",
             )
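
As a minimal standalone sketch (the tensor names and shapes below are illustrative, not the module's actual configuration), the core change in the hunks above is to swap a data-dependent narrow-then-concat of the KV cache for a concat of the full cache, which keeps the key/value length static under torch.export / dynamo:

    # Illustrative sketch only; shapes are made up.
    import torch

    bsz, n_heads, max_seq_len, head_dim = 1, 8, 1024, 64
    input_pos = 37                                  # tokens already written to the cache
    k_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
    k_new = torch.randn(bsz, n_heads, 1, head_dim)  # single decode step, seqlen == 1

    # Old path: narrow the cache to the filled prefix, then concat.
    # The result length (input_pos + 1) depends on a runtime value,
    # i.e. a data-dependent / dynamic shape.
    k_dynamic = torch.cat(
        [torch.narrow(k_cache, dim=2, start=0, length=input_pos), k_new], dim=2
    )
    assert k_dynamic.shape[2] == input_pos + 1

    # New path in this commit: concat the whole cache with the new token.
    # The result length is always max_seq_len + 1, a static shape; the
    # attn_mask passed into forward is then expected to mask out the
    # unused cache slots.
    k_static = torch.cat([k_cache, k_new], dim=2)
    assert k_static.shape[2] == max_seq_len + 1

The trade-off is that attention always runs over max_seq_len + 1 key/value positions, no matter how few cache slots are actually filled.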

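The final hunk grows the precomputed causal mask from max_seq_len to max_seq_len + 1 on each side; a plausible reading (an assumption, not something the commit states) is that the concatenated keys and values now have length max_seq_len + 1 during single-token decode, so a mask row must cover that extra column:

    # Sketch of the mask sizing, under the assumption stated above.
    import torch

    max_seq_len = 1024
    causal_mask = torch.tril(
        torch.ones(max_seq_len + 1, max_seq_len + 1, dtype=torch.float16, device="cpu")
    )

    # After k = torch.cat([k_cache, k_new], dim=2), the key length is
    # max_seq_len + 1, matching the mask width.
    assert causal_mask.shape[-1] == max_seq_len + 1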