@@ -723,18 +723,15 @@ async def generate(
723723 # TODO(lvhan) VLM doesn't support input_ids as an argument.
724724 # Figure out a graceful way to handle the invalid input
725725 prompt_input = dict (input_ids = input_ids )
726+
726727 if gen_config .max_new_tokens is None :
727- # for interactive endpoint, will try maximum possible token num
728- gen_config .max_new_tokens = max (128 , self .session_len - self .id2step [session_id ] - len (input_ids ))
729- elif self .id2step [session_id ] + len (input_ids ) + gen_config .max_new_tokens > self .session_len :
730- gen_config .max_new_tokens = max (self .session_len - self .id2step [session_id ] - len (input_ids ), 128 )
731- logger .error (f'Truncate max_new_tokens to { gen_config .max_new_tokens } ' )
732- if self .id2step [session_id ] + len (input_ids ) + gen_config .max_new_tokens > self .session_len :
733- logger .error (f'run out of tokens. session={ session_id } .' )
734- yield GenOut ('' , self .id2step [session_id ], len (input_ids ), 0 , 'length' )
735- if sequence_end is True and sequence_start is False :
736- await self .end_session (session_id )
737- return
728+ gen_config .max_new_tokens = max (0 , self .session_len - self .id2step [session_id ] - len (input_ids ))
729+ if gen_config .max_new_tokens == 0 :
730+ logger .error (f'run out of tokens. session={ session_id } .' )
731+ yield GenOut ('' , self .id2step [session_id ], len (input_ids ), 0 , 'length' )
732+ if sequence_end is True and sequence_start is False :
733+ await self .end_session (session_id )
734+ return
738735
739736 def is_error (status ):
740737 return status not in [ResponseType .SUCCESS , ResponseType .FINISH ]
@@ -826,8 +823,7 @@ def is_error(status):
826823 metrics_processor .increment_finished_requests ()
827824
828825 if not is_error (outputs .status ):
829- finish_reason = 'length' \
830- if gen_len >= gen_config .max_new_tokens else 'stop'
826+ finish_reason = 'stop' if outputs .token_ids [- 1 ] in stop_ids else 'length'
831827 # utf-8 char at the end means it's a potential unfinished
832828 # byte sequence
833829 if not response .endswith ('�' ):
0 commit comments