@@ -474,7 +474,6 @@ def forward(
 
         if self.decode_kv_cache_as_io:
             assert self.use_kv_cache
-            mask = attn_mask
             if self.use_additive_kv_cache_update:
                 assert seqlen == 1
                 # assert cache_pos_mask is not None
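The additive branch above asserts `seqlen == 1` and references a `cache_pos_mask`. A minimal sketch of what an additive cache write can look like under those constraints, assuming a one-hot `(1, 1, max_seq_len, 1)` position mask; the `additive_kv_update` helper and all shapes here are illustrative assumptions, not this file's actual code:

```python
import torch

def additive_kv_update(cache, new, cache_pos_mask):
    # cache: (bsz, n_kv_heads, max_seq_len, head_dim)
    # new:   (bsz, n_kv_heads, 1, head_dim), the single decoded token (seqlen == 1)
    # cache_pos_mask: (1, 1, max_seq_len, 1), one-hot at the write position.
    # Zero out the target slot, then add the new entry broadcast into it.
    return cache * (1 - cache_pos_mask) + new * cache_pos_mask

bsz, heads, max_len, dim = 1, 4, 8, 16
k_cache = torch.zeros(bsz, heads, max_len, dim)
k = torch.randn(bsz, heads, 1, dim)
pos = 3
cache_pos_mask = torch.zeros(1, 1, max_len, 1)
cache_pos_mask[0, 0, pos, 0] = 1.0
k_cache = additive_kv_update(k_cache, k, cache_pos_mask)
assert torch.equal(k_cache[0, 0, pos], k[0, 0, 0])
```

This formulation avoids in-place mutation, which matters when the cache is passed in and out as a model I/O.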
@@ -494,16 +493,12 @@ def forward(
                 # model (model code -7)". It does run on GPU.
                 # I suspect it is related to the data-dependent / dynamic shape of k, v, and mask
 
-                buffer = 2  # needed to make dynamo happy
-                torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
-                mask = torch.narrow(mask, dim=1, start=0, length=input_pos + seqlen)
+                # buffer = 2  # needed to make dynamo happy
+                # torch._check(input_pos + seqlen <= self.max_seq_len - buffer)
+                # mask = torch.narrow(mask, dim=1, start=0, length=input_pos + seqlen)
 
-                k = torch.cat(
-                    [torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], axis=2
-                )
-                v = torch.cat(
-                    [torch.narrow(v_cache, dim=2, start=0, length=input_pos), v], axis=2
-                )
+                k = torch.cat([k_cache, k], axis=2)
+                v = torch.cat([v_cache, v], axis=2)
 
                 # # # Attempt 2 to use torch.cat
                 # # # Dynamo fails with "expand: attempting to expand a dimension of length u0 + 1024!"
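This hunk swaps a data-dependent cache slice (`torch.narrow` up to `input_pos`, which the comments above blame for dynamo/Core ML export failures) for a concatenation of the full fixed-size cache, leaving the unused slots to be suppressed by the attention mask. A minimal sketch of the shape difference, with illustrative dimensions:

```python
import torch

bsz, heads, max_len, dim, seqlen, input_pos = 1, 4, 8, 16, 1, 3
k_cache = torch.zeros(bsz, heads, max_len, dim)
k = torch.randn(bsz, heads, seqlen, dim)

# Removed path: data-dependent length -> dynamic shape (input_pos + seqlen)
k_dynamic = torch.cat(
    [torch.narrow(k_cache, dim=2, start=0, length=input_pos), k], dim=2
)
assert k_dynamic.shape[2] == input_pos + seqlen

# New path: static shape (max_len + seqlen); stale cache slots must be
# masked out by attn_mask rather than sliced away.
k_static = torch.cat([k_cache, k], dim=2)
assert k_static.shape[2] == max_len + seqlen
```

The static variant trades extra compute over masked positions for shapes the exporter can resolve at trace time.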
@@ -524,13 +519,12 @@ def forward(
                 v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)
         else:
             assert not self.use_kv_cache
-            mask = attn_mask
 
         # grouped multiquery attention: expand out keys and values
         if self.n_rep > 1:
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
-        output = torch.ops.coreml.sdpa(q, k, v, mask)
+        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
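The surrounding context expands `n_kv_heads` key/value heads by `n_rep` before calling the custom `torch.ops.coreml.sdpa` op, which is now fed `attn_mask` directly since the intermediate `mask` alias was removed. A self-contained sketch of that grouped-query expansion, using stock `F.scaled_dot_product_attention` as a stand-in for the Core ML op; all dimensions are illustrative:

```python
import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, seqlen, dim = 1, 8, 2, 4, 16
n_rep = n_heads // n_kv_heads

q = torch.randn(bsz, n_heads, seqlen, dim)
k = torch.randn(bsz, n_kv_heads, seqlen, dim)
v = torch.randn(bsz, n_kv_heads, seqlen, dim)
attn_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))

# Repeat each KV head n_rep times so q and k/v head counts line up.
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Same output reshape as the diff: (bsz, n_heads, seqlen, dim) -> (bsz, seqlen, n_heads * dim)
out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
assert out.shape == (bsz, seqlen, n_heads * dim)
```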
@@ -680,8 +674,8 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_seq_len + 1,
+                self.max_seq_len + 1,
                 dtype=torch.float16,
                 device="cpu",
             )
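The final hunk grows the precomputed causal mask from `max_seq_len` to `max_seq_len + 1` per side. A sketch of the resulting mask; the reading that the extra row/column gives the export path one slot of headroom (replacing the commented-out `buffer`/`torch._check` workaround) is an inference from this diff, not something it states:

```python
import torch

max_seq_len = 8
# Lower-triangular mask, one position larger than the cache in each dimension.
causal_mask = torch.tril(
    torch.ones(
        max_seq_len + 1,
        max_seq_len + 1,
        dtype=torch.float16,
        device="cpu",
    )
)
assert causal_mask.shape == (max_seq_len + 1, max_seq_len + 1)
assert causal_mask[0, 1] == 0 and causal_mask[1, 0] == 1
```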