@@ -479,6 +479,37 @@ def forward(
479479 v = v_cache + v_update
480480 assert k .shape == k_cache .shape
481481 assert v .shape == v_cache .shape
482+
483+
484+ # # Attempt 1 to use torch.cat:
485+ # # This fails to lower to ET during to_executorch due to a dynamo error related to the
486+ # # delegate call. We can talk to the compiler about this, but the bigger issue is that although
487+ # # the CoreML mlpackage lowers, it fails at runtime on CPU/ANE with "input data broken / unsupported
488+ # # model (model code -7)". It does run on GPU.
489+ # # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask
490+
491+ # buffer = 2 # needed to make dynamo happy
492+ # input_pos_item = input_pos[0].item()
493+ # torch._check_is_size(input_pos_item)
494+ # torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
495+ # mask = torch.narrow(mask, dim=3, start=0, length=input_pos_item + seqlen)
496+
497+ # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
498+ # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
499+
500+
501+ # # Attempt 2 to use torch.cat
502+ # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
503+ # # I'm not confident this variant will work in CoreML even if we can export it, though.
504+ # buffer = 2
505+ # input_pos_item = input_pos[0].item()
506+ # torch._check_is_size(input_pos_item)
507+ # torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
508+
509+ # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
510+ # k = k.expand(k_cache.size())
511+ # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
512+ # v = v.expand(v_cache.size())
482513 else :
483514 k = torch .ops .aten .index_put (k_cache , [None , None , input_pos , None ], k )
484515 v = torch .ops .aten .index_put (v_cache , [None , None , input_pos , None ], v )
0 commit comments