Skip to content

Commit dae8f94

Browse files
committed
up
1 parent c962278 commit dae8f94

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

examples/models/llama/llama_transformer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,37 @@ def forward(
479479
v = v_cache + v_update
480480
assert k.shape == k_cache.shape
481481
assert v.shape == v_cache.shape
482+
483+
484+
# # Attempt 1 to use torch.cat:
485+
# # This fails to lower to ET during to_executorch due to a dynamo error related to the
486+
# # delegate call. We can talk to the compiler team about this, but the bigger issue is that although
487+
# # the CoreML mlpackage lowers, it fails at runtime on CPU/ANE with "input data broken / unsupported
488+
# # model (model code -7)". It does run on GPU.
489+
# # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask
490+
491+
# buffer = 2 # needed to make dynamo happy
492+
# input_pos_item = input_pos[0].item()
493+
# torch._check_is_size(input_pos_item)
494+
# torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
495+
# mask = torch.narrow(mask, dim=3, start=0, length=input_pos_item + seqlen)
496+
497+
# k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
498+
# v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
499+
500+
501+
# # Attempt 2 to use torch.cat
502+
# # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
503+
# # I'm not confident this variant will work in CoreML even if we can export it, though.
504+
# buffer = 2
505+
# input_pos_item = input_pos[0].item()
506+
# torch._check_is_size(input_pos_item)
507+
# torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
508+
509+
# k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
510+
# k = k.expand(k_cache.size())
511+
# v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
512+
# v = v.expand(v_cache.size())
482513
else:
483514
k = torch.ops.aten.index_put(k_cache, [None, None, input_pos, None], k)
484515
v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)

0 commit comments

Comments
 (0)