
Commit bc87b22

Fix the logic of calculating max_new_tokens and determining finish_reason (#3727)
* fix the logic of computing max_new_tokens
* update
* fix
1 parent 3de7f14 commit bc87b22

File tree

lmdeploy/serve/async_engine.py
lmdeploy/turbomind/turbomind.py

2 files changed: +10 -13 lines changed


lmdeploy/serve/async_engine.py

Lines changed: 9 additions & 13 deletions
@@ -723,18 +723,15 @@ async def generate(
         # TODO(lvhan) VLM doesn't support input_ids as an argument.
         # Figure out a graceful way to handle the invalid input
         prompt_input = dict(input_ids=input_ids)
+
         if gen_config.max_new_tokens is None:
-            # for interactive endpoint, will try maximum possible token num
-            gen_config.max_new_tokens = max(128, self.session_len - self.id2step[session_id] - len(input_ids))
-        elif self.id2step[session_id] + len(input_ids) + gen_config.max_new_tokens > self.session_len:
-            gen_config.max_new_tokens = max(self.session_len - self.id2step[session_id] - len(input_ids), 128)
-            logger.error(f'Truncate max_new_tokens to {gen_config.max_new_tokens}')
-        if self.id2step[session_id] + len(input_ids) + gen_config.max_new_tokens > self.session_len:
-            logger.error(f'run out of tokens. session={session_id}.')
-            yield GenOut('', self.id2step[session_id], len(input_ids), 0, 'length')
-            if sequence_end is True and sequence_start is False:
-                await self.end_session(session_id)
-            return
+            gen_config.max_new_tokens = max(0, self.session_len - self.id2step[session_id] - len(input_ids))
+        if gen_config.max_new_tokens == 0:
+            logger.error(f'run out of tokens. session={session_id}.')
+            yield GenOut('', self.id2step[session_id], len(input_ids), 0, 'length')
+            if sequence_end is True and sequence_start is False:
+                await self.end_session(session_id)
+            return
 
         def is_error(status):
             return status not in [ResponseType.SUCCESS, ResponseType.FINISH]
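
The removed branches floored the budget at 128 tokens and then needed a second overflow check to catch the cases where even that did not fit; the patch collapses this into a single clamp at zero, so a zero budget finishes the request with reason 'length' right away. A minimal standalone sketch of the new arithmetic (the helper remaining_budget is illustrative, not part of lmdeploy):

from typing import Optional


def remaining_budget(session_len: int, step: int, prompt_len: int,
                     max_new_tokens: Optional[int]) -> int:
    """Mirror the patched branch: when the caller sets no explicit limit,
    grant whatever still fits in the context window, floored at zero."""
    if max_new_tokens is None:
        return max(0, session_len - step - prompt_len)
    return max_new_tokens


# 4096-token window, 4000 tokens already consumed, 96-token prompt:
# nothing is left, so the request finishes immediately with 'length'.
assert remaining_budget(4096, 4000, 96, None) == 0
# The old max(128, ...) floor would have reported a phantom 128-token
# budget here and relied on the follow-up overflow check to bail out.
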
@@ -826,8 +823,7 @@ def is_error(status):
             metrics_processor.increment_finished_requests()
 
             if not is_error(outputs.status):
-                finish_reason = 'length' \
-                    if gen_len >= gen_config.max_new_tokens else 'stop'
+                finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'
                 # utf-8 char at the end means it's a potential unfinished
                 # byte sequence
                 if not response.endswith('�'):
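
The finish reason is now read off the output itself: a trailing stop token means the model stopped on its own ('stop'); anything else means the token budget was exhausted ('length'). The old gen_len >= max_new_tokens comparison labelled a stop token emitted on exactly the last allowed step as 'length'. A hedged sketch of the new decision (token_ids and stop_ids stand in for the variables in the diff):

from typing import Sequence, Set


def finish_reason(token_ids: Sequence[int], stop_ids: Set[int]) -> str:
    """Classify why generation ended, as in the patched branch: a trailing
    stop token is a model-initiated stop, anything else ran out of budget."""
    return 'stop' if token_ids and token_ids[-1] in stop_ids else 'length'


# With an EOS id of 2: ending on EOS is 'stop' even if it lands on the
# last allowed step; ending on any other token is 'length'.
assert finish_reason([5, 9, 2], {2}) == 'stop'
assert finish_reason([5, 9, 7], {2}) == 'length'
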

lmdeploy/turbomind/turbomind.py

Lines changed: 1 addition & 0 deletions
@@ -686,6 +686,7 @@ async def async_stream_infer(self,
             if status in [7, 8]:  # finish / canceled
                 finish, status = True, 0
             elif status:
+                logger.error(f'internal error. status_code {status}')
                 yield self._get_error_output()
                 break
 
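
The turbomind change only adds logging, but it pins down the dispatch: status codes 7 and 8 mean finished or canceled (per the comment in the diff), and any other non-zero status is an internal error that is now logged before the error output is yielded. A condensed, illustrative sketch of that dispatch (the function shape is an assumption, not lmdeploy's API):

import logging

logger = logging.getLogger(__name__)


def classify_status(status: int):
    """Return (finish, normalized_status), logging internal errors the way
    the patched async_stream_infer now does before yielding error output."""
    if status in [7, 8]:  # finish / canceled
        return True, 0
    if status:  # any other non-zero code is an internal error
        logger.error(f'internal error. status_code {status}')
        return True, status
    return False, status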
