@@ -196,11 +196,11 @@ def update(
196196            # narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) 
197197
198198            if  self .transpose_cache :
199-                 narrowed_k  =  self .k_cache [:, :, input_pos : (input_pos + seq_length ), :]
200-                 narrowed_v  =  self .v_cache [:, :, input_pos : (input_pos + seq_length ), :]
199+                 narrowed_k  =  self .k_cache [:, :, input_pos  :  (input_pos   +   seq_length ), :]
200+                 narrowed_v  =  self .v_cache [:, :, input_pos  :  (input_pos   +   seq_length ), :]
201201            else :
202-                 narrowed_k  =  self .k_cache [:, input_pos : (input_pos + seq_length ), :, :]
203-                 narrowed_v  =  self .v_cache [:, input_pos : (input_pos + seq_length ), :, :]
202+                 narrowed_k  =  self .k_cache [:, input_pos  :  (input_pos   +   seq_length ), :, :]
203+                 narrowed_v  =  self .v_cache [:, input_pos  :  (input_pos   +   seq_length ), :, :]
204204
205205            narrowed_k .copy_ (k_val )
206206            narrowed_v .copy_ (v_val )
@@ -257,7 +257,8 @@ def forward(
257257            torch ._check (start_pos  <  self .max_seq_len )
258258            seq_length  =  q .size (2 )
259259            # pyre-ignore: Incompatible parameter type [6] 
260-             attn_mask  =  mask .narrow (0 , start_pos , seq_length )
260+             # attn_mask = mask.narrow(0, start_pos, seq_length) 
261+             attn_mask  =  mask [start_pos  : (start_pos  +  seq_length )]
261262        else :
262263            attn_mask  =  mask [None , None , input_pos ]
263264
@@ -518,9 +519,13 @@ def forward(
518519                torch ._check_is_size (input_pos_item )
519520                torch ._check (input_pos_item  <  self .params .max_seq_len )
520521                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor 
521-                 freqs_cos  =  self .freqs_cos [input_pos_item :(input_pos_item  +  seqlen )] #.narrow(0, input_pos_item, seqlen) 
522+                 freqs_cos  =  self .freqs_cos [
523+                     input_pos_item  : (input_pos_item  +  seqlen )
524+                 ]  # .narrow(0, input_pos_item, seqlen) 
522525                # pyre-ignore: Incompatible parameter type [6] 
523-                 freqs_sin  =  self .freqs_sin [input_pos_item :(input_pos_item  +  seqlen )] #.narrow(0, input_pos_item, seqlen) 
526+                 freqs_sin  =  self .freqs_sin [
527+                     input_pos_item  : (input_pos_item  +  seqlen )
528+                 ]  # .narrow(0, input_pos_item, seqlen) 
524529            else :
525530                # When not using dynamic shape, use of the .item results in 
526531                # symints, due to querying the data from tensor. 
0 commit comments