Skip to content

Commit 5f71aa9

Browse files
authored
[Serving] Use stop strs and token ids for completions (#2534)
This PR applies the stop strings and stop token ids defined in the conversation template to raw text completions, so that generation can stop whenever the model outputs a stop token id or stop string. Prior to this commit, raw text generation never stopped when max tokens was not given. This commit helps reduce the frequency of such events. Nevertheless, if the model does not output a stop string/token id, the generation will still not stop.
1 parent 50a1a7c commit 5f71aa9

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

python/mlc_llm/serve/engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,7 @@ async def _handle_completion(
12981298
self.state,
12991299
self.tokenizer,
13001300
self.max_input_sequence_length,
1301+
self.conv_template.model_copy(deep=True),
13011302
)
13021303
_ = prompt_length
13031304
if echo_response is not None:
@@ -1840,6 +1841,7 @@ def _handle_completion(
18401841
self.state,
18411842
self.tokenizer,
18421843
self.max_input_sequence_length,
1844+
self.conv_template.model_copy(deep=True),
18431845
)
18441846
_ = prompt_length
18451847
if echo_response is not None:

python/mlc_llm/serve/engine_base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,12 +862,13 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments
862862
return response
863863

864864

865-
def process_completion_request(
865+
def process_completion_request( # pylint: disable=too-many-arguments
866866
request: openai_api_protocol.CompletionRequest,
867867
request_id: str,
868868
engine_state: EngineState,
869869
tokenizer: Tokenizer,
870870
max_input_sequence_length: int,
871+
conv_template: Conversation,
871872
) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]:
872873
"""Process the given CompletionRequest, apply request validity
873874
checks, and return the processed prompts, and other info.
@@ -889,6 +890,9 @@ def process_completion_request(
889890
max_input_sequence_length : int
890891
The maximum allowed total prompt length.
891892
893+
conv_template : Conversation
894+
The conversation template of the model.
895+
892896
Returns
893897
-------
894898
prompt : List[int]
@@ -917,7 +921,11 @@ def process_completion_request(
917921
assert isinstance(prompt, list)
918922

919923
# Process generation config. Create request id.
920-
generation_cfg = engine_utils.get_generation_config(request)
924+
generation_cfg = engine_utils.get_generation_config(
925+
request,
926+
extra_stop_token_ids=conv_template.stop_token_ids,
927+
extra_stop_str=conv_template.stop_str,
928+
)
921929

922930
# - Echo back the prompt.
923931
echo_response = None

0 commit comments

Comments
 (0)