@@ -86,6 +86,11 @@ def generate(
     if extra_kwargs is not None:
         kwargs.update(extra_kwargs)
 
+    # if we didn't specify last_n_tokens and only_last_token is set to True, set last_n_tokens to 1, otherwise use default
+    # we do this since the output shape of only_last_token is different and therefore would change the logic in generate
+    if "last_n_tokens" not in kwargs and kwargs.get("only_last_token", False):
+        kwargs["last_n_tokens"] = 1
+
     is_fp8 = "fp8" in kwargs["attn_name"]
     if isinstance(input_ids, torch.Tensor):
         if len(input_ids.shape) == 1:
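A minimal sketch of the backward-compatibility default added in this hunk, assuming callers may still pass the old `only_last_token` flag; `_normalize_last_n_tokens` is a hypothetical helper name used only for illustration, not part of the repo:

```python
# Hypothetical helper mirroring the kwargs normalization above: an explicit
# last_n_tokens always wins, and only_last_token=True alone maps to last_n_tokens=1.
def _normalize_last_n_tokens(kwargs: dict) -> dict:
    if "last_n_tokens" not in kwargs and kwargs.get("only_last_token", False):
        kwargs["last_n_tokens"] = 1
    return kwargs

assert _normalize_last_n_tokens({"only_last_token": True})["last_n_tokens"] == 1
assert "last_n_tokens" not in _normalize_last_n_tokens({"only_last_token": False})
assert _normalize_last_n_tokens({"only_last_token": True, "last_n_tokens": 3})["last_n_tokens"] == 3
```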
@@ -233,7 +238,7 @@ def generate(
     kwargs["current_tkv_mask"] = None
     kwargs["left_padded_prompt_mask"] = None
     kwargs["use_cache"] = use_cache
-    only_last_token = kwargs.get("only_last_token", False)
+    last_n_tokens = kwargs.get("last_n_tokens", 0)
 
     prompt_length = input_ids.shape[1]
 
@@ -296,21 +301,20 @@ def generate(
                     t1._scale = current_kv_scales[layer_idx][0][seq_i].reshape(-1)
                     t2._scale = current_kv_scales[layer_idx][1][seq_i].reshape(-1)
 
-            only_last_token = kwargs.get("only_last_token", False)
+            last_n_tokens = kwargs.get("last_n_tokens", 0)
             output, current_kv_cache = model(
                 input_ids_i,
                 slot_mapping=slot_mapping_i,
                 position_ids=position_ids_i,
                 mask=mask_i,
                 past_key_value_states=current_kv_cache,
                 use_cache=kwargs["use_cache"],
-                only_last_token=only_last_token,
+                last_n_tokens=last_n_tokens,
                 attn_name=kwargs["attn_name"],
             )
 
             # only last token must be handled here to properly stack the tensors
-            if not only_last_token:
-                output = output[:, -1, :]
+            output = output[:, -1, :]
 
             # TODO: Figure out how to do this cleanly
             if "fp8" in kwargs["attn_name"]:
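To see why the `output[:, -1, :]` slice is now unconditional, here is an illustrative shape sketch (assumed shapes, not the repo's API): with `last_n_tokens`, each per-sequence prefill forward keeps a token dimension, so the final position is selected before the batch-size-1 results are stacked.

```python
import torch

vocab = 32000
# one forward pass per prompt (batch size 1 prefill), with differing kept lengths
per_seq_outputs = [torch.randn(1, n, vocab) for n in (5, 9, 1)]
# keep only the final position of each sequence so the tensors can be stacked
last_logits = [out[:, -1, :] for out in per_seq_outputs]   # each [1, vocab]
logits = torch.cat(last_logits, dim=0)                     # [3, vocab]
assert logits.shape == (3, vocab)
```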
@@ -341,6 +345,7 @@ def generate(
         kwargs["position_ids"] = kwargs["position_ids"].clone(
             memory_format=torch.contiguous_format
         )
+        kwargs["last_n_tokens"] = 1
 
         # we no longer have a global pos_i, each sequence has its own pos_i
         slot_mapping = []
@@ -396,8 +401,7 @@ def generate(
         # typically this is done outside of prefill/decode logic, but since this logic already exists as part of the
         # conditional for prefill (since prefill does this within a loop for each batch size 1 prefill), we also provide
         # this same logic as part of the decode conditional
-        if not only_last_token:
-            logits = logits[:, -1, :]
+        logits = logits[:, -1, :]
 
         output = (logits, past_key_value_states)
 
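A toy sketch of the decode-side convention this change assumes: each decode step forces `last_n_tokens=1`, so the model output stays `[batch, 1, vocab]` and the same `logits[:, -1, :]` slice applies to both prefill and decode before the next token is chosen. `pick_next_token` is an illustrative name only.

```python
import torch

def pick_next_token(logits_btv: torch.Tensor) -> torch.Tensor:
    last = logits_btv[:, -1, :]           # [batch, vocab]; a no-op squeeze when the token dim is 1
    return torch.argmax(last, dim=-1)     # greedy choice, purely for illustration

decode_logits = torch.randn(4, 1, 32000)  # batch of 4, one position per decode step
assert pick_next_token(decode_logits).shape == (4,)
```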