Skip to content

Commit c0d5c78

Browse files
JadoTucodego7250
authored and committed
[None] [feat] add eos_token_id in generation_config to sampling params (NVIDIA#9514)
Signed-off-by: jiant <[email protected]>
1 parent 15f34f5 commit c0d5c78

File tree

2 files changed

+18
-24
lines changed

2 files changed

+18
-24
lines changed

tensorrt_llm/sampling_params.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,6 @@ def _setup(
373373
if self.end_id is None:
374374
self.end_id = tokenizer.eos_token_id
375375
self.pad_id = tokenizer.pad_token_id
376-
# kimi_k2 model uses the eos_token_id in generation config
377-
if (
378-
hf_model_config is not None
379-
and hf_model_config.model_type == "kimi_k2"
380-
and generation_config is not None
381-
and isinstance(generation_config.eos_token_id, int)
382-
):
383-
self.end_id = generation_config.eos_token_id
384376

385377
if self.pad_id is None:
386378
self.pad_id = self.end_id
@@ -400,24 +392,26 @@ def _encode(tokenizer, text, add_special_tokens):
400392
strs = [self.stop] if isinstance(self.stop, str) else self.stop
401393
self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]
402394

403-
# add generation_config to stop word list, only in qwen3-next now
404-
if (
405-
hf_model_config is not None
406-
and hf_model_config.model_type == "qwen3_next"
407-
and generation_config is not None
408-
and isinstance(generation_config.eos_token_id, List)
409-
and all(isinstance(i, int) for i in generation_config.eos_token_id)
410-
):
411-
if self._stop_word_ids:
395+
# Add eos_token_id in generation_config to _stop_word_ids
396+
# Refer to https://huggingface.co/docs/hub/en/transformers#transformers-repository-files and
397+
# https://github.com/huggingface/transformers/blob/1ae4d917ed3badbdb1ffc167e0529f5a6d3c080d/src/transformers/generation/stopping_criteria.py#L451C1-L451C42
398+
# The eos_token_id values in generation_config are really meant to stop the text generation.
399+
if generation_config is not None and generation_config.eos_token_id is not None:
400+
if isinstance(generation_config.eos_token_id, int):
401+
generation_eos_token_ids = [generation_config.eos_token_id]
402+
else: # always List[int]
403+
generation_eos_token_ids = generation_config.eos_token_id
404+
405+
if self._stop_word_ids is None:
406+
self._stop_word_ids = [generation_eos_token_ids]
407+
else:
412408
all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
413-
from_generation_stop_tokens = [
414-
i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
409+
from_generation_stop_token_ids = [
410+
i for i in generation_eos_token_ids if i not in all_stop_tokens_id
415411
]
416412

417-
if from_generation_stop_tokens:
418-
self._stop_word_ids.append(from_generation_stop_tokens)
419-
else:
420-
self._stop_word_ids = [generation_config.eos_token_id]
413+
if from_generation_stop_token_ids:
414+
self._stop_word_ids.append(from_generation_stop_token_ids)
421415

422416
return self
423417

tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI,
110110
"content": "You are a helpful assistant."
111111
}, {
112112
"role": "user",
113-
"content": "What is the capital of France?"
113+
"content": "What is the capital of France? please in detail."
114114
}]
115115
# Test top_logprobs=1
116116
chat_completion = await async_client.chat.completions.create(

0 commit comments

Comments
 (0)