@@ -474,59 +474,63 @@ def forward(
 
         if self.decode_kv_cache_as_io:
             assert self.use_kv_cache
-            # mask = self.mask[None, None, input_pos]
+            mask = attn_mask
             if self.use_additive_kv_cache_update:
-                assert cache_pos_mask is not None
                 assert seqlen == 1
-                k_update = cache_pos_mask * k
-                v_update = cache_pos_mask * v
-                k = k_cache + k_update
-                v = v_cache + v_update
-                assert k.shape == k_cache.shape
-                assert v.shape == v_cache.shape
-
-                # # Attempt 1 to use torch.cat:
-                # # This fails to lower to ET during to_executorch due to a dynamo error related to the
-                # # delegate call. We can talk to compiler about this, but the bigger issue is although
-                # # the CoreML mlpackage lowers, it fails at runtime on CPU/ANE with "input data broken / unsupported
-                # # model (model code -7)". It does run on GPU.
-                # # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask
-
-                # buffer = 2  # needed to make dynamo happy
-                # input_pos_item = input_pos[0].item()
-                # torch._check_is_size(input_pos_item)
-                # torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
-                # mask = torch.narrow(mask, dim=3, start=0, length=input_pos_item + seqlen)
-
-                # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
-                # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
-
-                # # Attempt 2 to use torch.cat
-                # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
-                # # I'm not confident this variant will work in CoreML if we can export it, though.
+                # assert cache_pos_mask is not None
+                # k_update = cache_pos_mask * k
+                # v_update = cache_pos_mask * v
+                # print("k_update", k_update.shape)
+                # print("k_cache", k_cache.shape)
+                # k = k_cache + k_update
+                # v = v_cache + v_update
+                # assert k.shape == k_cache.shape
+                # assert v.shape == v_cache.shape
+
+                # Attempt 1 to use torch.cat:
+                # This fails to lower to ET during to_executorch due to a dynamo error related to the
+                # delegate call. We can talk to compiler about this, but the bigger issue is that although
+                # the CoreML mlpackage lowers, it fails at runtime on CPU/ANE with "input data broken / unsupported
+                # model (model code -7)". It does run on GPU.
+                # I suspect it is related to the data-dependent / dynamic shapes of k, v, and mask.
+
+                buffer = 2  # needed to make dynamo happy
+                torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
+                mask = torch.narrow(mask, dim=1, start=0, length=input_pos + seqlen)
+
+                k = torch.cat(
+                    [torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], axis=2
+                )
+                v = torch.cat(
+                    [torch.narrow(v_cache, dim=2, start=0, length=input_pos), v], axis=2
+                )
+
+                # # # Attempt 2 to use torch.cat
+                # # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
+                # # # I'm not confident this variant will work in CoreML if we can export it, though.
                 # buffer = 2
-                # input_pos_item = input_pos[0].item()
-                # torch._check_is_size(input_pos_item)
-                # torch._check(input_pos_item + seqlen <= self.max_seq_len - buffer)
+                # # input_pos_item = input_pos[0].item()
+                # # torch._check_is_size(input_pos_item)
+                # torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
 
-                # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos_item), k], axis=2)
-                # k = k.expand(k_cache.size())
-                # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos_item), v], axis=2)
+                # k = torch.cat([torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], axis=2)
+
+                # # torch.Size([1, 12, 1, 64]) torch.Size([1, 12, 1024, 64]) torch.Size([1, 12, 1024, 64])
+                # k = k.expand(k_cache.size())  # torch.Size([1, 12, 1024, 64])
+                # v = torch.cat([torch.narrow(v_cache, dim=2, start=0, length=input_pos), v], axis=2)
                 # v = v.expand(v_cache.size())
             else:
                 k = torch.ops.aten.index_put(k_cache, [None, None, input_pos, None], k)
                 v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)
         else:
             assert not self.use_kv_cache
-            # assert hasattr(self, "mask")
-
-            # mask = self.mask[:seqlen, :seqlen]
+            mask = attn_mask
 
         # grouped multiquery attention: expand out keys and values
         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
+        output = torch.ops.coreml.sdpa(q, k, v, mask)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
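For context on the additive branch this commit comments out: instead of scattering into the cache, it writes a single decoded token by multiplying it against a one-hot position mask and adding the result to the cache. A minimal eager-mode sketch, assuming cache_pos_mask is a (1, 1, max_seq_len, 1) one-hot float tensor and the target cache slot still holds zeros (toy shapes for illustration):

import torch

# Sketch of the additive KV cache update (assumed setup, not the model's).
max_seq_len, n_heads, head_dim = 1024, 12, 64
k_cache = torch.zeros(1, n_heads, max_seq_len, head_dim)
k = torch.randn(1, n_heads, 1, head_dim)  # seqlen == 1 during decode

input_pos = 5
cache_pos_mask = torch.zeros(1, 1, max_seq_len, 1)
cache_pos_mask[:, :, input_pos, :] = 1.0

# Broadcasting (1, 1, L, 1) * (1, H, 1, D) -> (1, H, L, D):
# k lands only at row input_pos, every other row stays zero.
k_update = cache_pos_mask * k
new_k_cache = k_cache + k_update  # add acts as a write because the slot is zero
assert new_k_cache.shape == k_cache.shape
assert torch.equal(new_k_cache[:, :, input_pos], k[:, :, 0])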
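The buffer = 2 and torch._check line exist to satisfy the exporter: input_pos + seqlen is data-dependent, and dynamo needs an explicit upper bound to prove the narrow/cat results fit within the fixed max_seq_len. A sketch of the export-time pattern, following the commented-out attempts where input_pos arrives as a 1-element tensor:

import torch

max_seq_len, seqlen, buffer = 1024, 1, 2
input_pos = torch.tensor([5])

input_pos_item = input_pos[0].item()  # an unbacked SymInt under torch.export
torch._check_is_size(input_pos_item)  # assert it is a valid non-negative size
torch._check(input_pos_item + seqlen <= max_seq_len - buffer)  # upper bound for dynamo

# With the bound established, data-dependent slicing can be proven in-range:
k_cache = torch.zeros(1, 12, max_seq_len, 64)
past = torch.narrow(k_cache, dim=2, start=0, length=input_pos_item)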
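The torch.cat path the diff enables slices the valid prefix out of the cache with torch.narrow and appends the new token, so downstream SDPA only ever sees input_pos + seqlen entries rather than the full padded cache. A standalone sketch of the same shape manipulation (eager mode only; export additionally needs the bounds check above):

import torch

def cat_kv(kv_cache: torch.Tensor, kv_new: torch.Tensor, input_pos: int) -> torch.Tensor:
    """Return the first `input_pos` cached rows with the new rows appended.

    kv_cache: (bsz, n_heads, max_seq_len, head_dim)
    kv_new:   (bsz, n_heads, seqlen, head_dim)
    """
    past = torch.narrow(kv_cache, dim=2, start=0, length=input_pos)
    return torch.cat([past, kv_new], dim=2)  # (bsz, n_heads, input_pos + seqlen, head_dim)

k_cache = torch.randn(1, 12, 1024, 64)
k_new = torch.randn(1, 12, 1, 64)
k = cat_kv(k_cache, k_new, input_pos=5)
assert k.shape == (1, 12, 6, 64)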
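In the non-additive branch, aten.index_put scatters the new K/V rows into the cache along the sequence dimension; a None entry in the index list means "take every index" for that dimension. An eager-mode equivalent with toy shapes:

import torch

k_cache = torch.zeros(1, 12, 1024, 64)
k = torch.randn(1, 12, 1, 64)
input_pos = torch.tensor([5])

# [None, None, input_pos, None] selects every batch/head/feature entry at the
# given sequence position(s); functionally the same as the in-place
# k_cache[:, :, input_pos, :] = k, but expressed as a functional op.
out = torch.ops.aten.index_put(k_cache, [None, None, input_pos, None], k)
assert torch.equal(out[:, :, 5], k[:, :, 0])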
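Finally, the repeat_interleave calls implement grouped/multiquery attention: each of the n_kv_heads cached K/V heads is shared by n_rep = n_heads // n_kv_heads query heads, so K and V are duplicated along dim=1 until they match q before the SDPA call. A sketch with assumed toy sizes:

import torch

n_heads, n_kv_heads = 12, 4
n_rep = n_heads // n_kv_heads  # 3 query heads per KV head
k = torch.randn(1, n_kv_heads, 16, 64)

k_expanded = k.repeat_interleave(n_rep, dim=1)  # (1, 12, 16, 64)
assert k_expanded.shape[1] == n_heads
# repeat_interleave duplicates consecutively: expanded heads 0..2 all come
# from KV head 0, heads 3..5 from KV head 1, and so on.
assert torch.equal(k_expanded[:, 0], k[:, 0]) and torch.equal(k_expanded[:, 2], k[:, 0])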