@@ -191,9 +191,16 @@ def update(
             # when dim_to_slice is 1
             # We use .narrow() here to make the compiler happy
             # pyre-ignore: Incompatible parameter type [6]
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+            # narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+            # # pyre-ignore: Incompatible parameter type [6]
+            # narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+
+            if self.transpose_cache:
+                narrowed_k = self.k_cache[:, :, input_pos : (input_pos + seq_length), :]
+                narrowed_v = self.v_cache[:, :, input_pos : (input_pos + seq_length), :]
+            else:
+                narrowed_k = self.k_cache[:, input_pos : (input_pos + seq_length), :, :]
+                narrowed_v = self.v_cache[:, input_pos : (input_pos + seq_length), :, :]
 
             narrowed_k.copy_(k_val)
             narrowed_v.copy_(v_val)
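
For reference, this hunk swaps `Tensor.narrow(dim, start, length)` for an ordinary slice along the cache's sequence dimension; both return views of the same storage, so `copy_` through either one writes into the cache in place. Below is a minimal standalone sketch of that equivalence with assumed shapes and the `transpose_cache == True` layout (`[bsz, n_heads, max_seq_len, head_dim]`); it is not the ExecuTorch `KVCache` module itself.

```python
# Illustrative only: assumed shapes, not the real KVCache class.
import torch

bsz, n_heads, max_seq_len, head_dim = 1, 8, 128, 64
k_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
k_val = torch.randn(bsz, n_heads, 1, head_dim)  # one decoded token

start_pos = 5
seq_length = k_val.size(2)

# View via narrow (what the deleted lines did)
narrowed = k_cache.narrow(2, start_pos, seq_length)
# View via slicing (what the added lines do)
sliced = k_cache[:, :, start_pos : (start_pos + seq_length), :]

sliced.copy_(k_val)                  # writes into k_cache in place
assert torch.equal(narrowed, k_val)  # same storage region either way
```

The added `if self.transpose_cache:` branch only changes which dimension is sliced (dim 2 vs. dim 1), mirroring the `dim_to_slice` logic the original narrow call relied on.
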
@@ -511,9 +518,9 @@ def forward(
                 torch._check_is_size(input_pos_item)
                 torch._check(input_pos_item < self.params.max_seq_len)
                 # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
+                freqs_cos = self.freqs_cos[input_pos_item : (input_pos_item + seqlen)]  # .narrow(0, input_pos_item, seqlen)
                 # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
+                freqs_sin = self.freqs_sin[input_pos_item : (input_pos_item + seqlen)]  # .narrow(0, input_pos_item, seqlen)
             else:
                 # When not using dynamic shape, use of the .item results in
                 # symints, due to querying the data from tensor.
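
This second hunk applies the same narrow-to-slice substitution to the precomputed RoPE tables, keeping the original `.narrow(...)` call as a trailing comment. A standalone sketch under assumed table sizes (not the real `Transformer.forward`), showing that the slice matches the narrow while keeping the `torch._check_is_size` / `torch._check` hints that bound the data-dependent `.item()` value for export with dynamic shapes:

```python
# Illustrative only: assumed RoPE table layout and sizes.
import torch

max_seq_len, half_head_dim = 128, 32
freqs_cos = torch.randn(max_seq_len, half_head_dim)

input_pos = torch.tensor([5])  # current decode position
seqlen = 1

input_pos_item = input_pos[-1].item()
torch._check_is_size(input_pos_item)        # value is a valid (non-negative) size
torch._check(input_pos_item < max_seq_len)  # and stays inside the table

via_narrow = freqs_cos.narrow(0, input_pos_item, seqlen)
via_slice = freqs_cos[input_pos_item : (input_pos_item + seqlen)]

assert torch.equal(via_narrow, via_slice)
```
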