@@ -179,9 +179,9 @@ def update(
179179 ) -> Tuple [torch .Tensor , torch .Tensor ]:
180180 # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
181181 if self .enable_dynamic_shape :
182- start_pos = input_pos [0 ].item ()
183- torch ._check_is_size (start_pos )
184- torch ._check (start_pos < self .max_seq_length )
182+ # start_pos = input_pos[0].item()
183+ # torch._check_is_size(start_pos)
184+ # torch._check(start_pos < self.max_seq_length)
185185 dim_to_slice = 2 if self .transpose_cache else 1
186186 seq_length = k_val .size (dim_to_slice )
187187 # Replace the entry in the cache for this token
@@ -252,13 +252,13 @@ def forward(
252252
253253 k , v = self .kv_cache .update (input_pos , k , v )
254254 if self .enable_dynamic_shape :
255- start_pos = input_pos [- 1 ].item ()
256- torch ._check_is_size (start_pos )
257- torch ._check (start_pos < self .max_seq_len )
255+ # start_pos = input_pos[-1].item()
256+ # torch._check_is_size(start_pos)
257+ # torch._check(start_pos < self.max_seq_len)
258258 seq_length = q .size (2 )
259259 # pyre-ignore: Incompatible parameter type [6]
260260 # attn_mask = mask.narrow(0, start_pos, seq_length)
261- attn_mask = mask [start_pos : (start_pos + seq_length )]
261+ attn_mask = mask [input_pos : (input_pos + seq_length )]
262262 else :
263263 attn_mask = mask [None , None , input_pos ]
264264
@@ -515,16 +515,16 @@ def forward(
515515
516516 if self .params .enable_dynamic_shape :
517517 # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
518- input_pos_item = input_pos [- 1 ].item ()
519- torch ._check_is_size (input_pos_item )
520- torch ._check (input_pos_item < self .params .max_seq_len )
518+ # input_pos_item = input_pos[-1].item()
519+ # torch._check_is_size(input_pos_item)
520+ # torch._check(input_pos_item < self.params.max_seq_len)
521521 # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
522522 freqs_cos = self .freqs_cos [
523- input_pos_item : (input_pos_item + seqlen )
523+ input_pos : (input_pos + seqlen )
524524 ] # .narrow(0, input_pos_item, seqlen)
525525 # pyre-ignore: Incompatible parameter type [6]
526526 freqs_sin = self .freqs_sin [
527- input_pos_item : (input_pos_item + seqlen )
527+ input_pos : (input_pos + seqlen )
528528 ] # .narrow(0, input_pos_item, seqlen)
529529 else :
530530 # When not using dynamic shape, use of the .item results in