
Commit 44b0f8c

[None] [fix] Revert "[None] [feat] add eos_token_id in generation_config to sampling params" (#10002)
1 parent 63e7a2f commit 44b0f8c
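
For context, `eos_token_id` in a Hugging Face `generation_config.json` may be either a single int or a list of ints, which is why the model-specific branches in the diff below check its type before using it. A minimal sketch of inspecting it with `transformers` (the model id here is only illustrative, not one referenced by this commit):

```python
# Hedged sketch: read eos_token_id from a model's generation config.
# The model id is an arbitrary example, not tied to this commit.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# May print a single int, or a list such as [151645, 151643].
print(gen_cfg.eos_token_id)
```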

File tree

2 files changed: +24 −18 lines changed


tensorrt_llm/sampling_params.py

Lines changed: 23 additions & 17 deletions
```diff
@@ -373,6 +373,14 @@ def _setup(
         if self.end_id is None:
             self.end_id = tokenizer.eos_token_id
             self.pad_id = tokenizer.pad_token_id
+            # kimi_k2 model uses the eos_token_id in generation config
+            if (
+                hf_model_config is not None
+                and hf_model_config.model_type == "kimi_k2"
+                and generation_config is not None
+                and isinstance(generation_config.eos_token_id, int)
+            ):
+                self.end_id = generation_config.eos_token_id

         if self.pad_id is None:
             self.pad_id = self.end_id
```
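
As a reading aid, here is a self-contained sketch of the `end_id` resolution this hunk introduces. `resolve_end_id` and the stand-in objects are hypothetical, and the token ids are placeholders rather than real kimi_k2 values:

```python
# Hedged sketch of the kimi_k2 branch above, with stand-in objects.
from types import SimpleNamespace

def resolve_end_id(end_id, tokenizer, hf_model_config, generation_config):
    """Prefer generation_config.eos_token_id over the tokenizer's eos
    token for kimi_k2, but only when end_id was not set explicitly."""
    if end_id is None:
        end_id = tokenizer.eos_token_id
        if (
            hf_model_config is not None
            and hf_model_config.model_type == "kimi_k2"
            and generation_config is not None
            and isinstance(generation_config.eos_token_id, int)
        ):
            end_id = generation_config.eos_token_id
    return end_id

tokenizer = SimpleNamespace(eos_token_id=2, pad_token_id=0)
hf_model_config = SimpleNamespace(model_type="kimi_k2")
generation_config = SimpleNamespace(eos_token_id=12345)  # placeholder id

# generation_config wins for kimi_k2; prints 12345
print(resolve_end_id(None, tokenizer, hf_model_config, generation_config))
# an explicit end_id is left untouched; prints 7
print(resolve_end_id(7, tokenizer, hf_model_config, generation_config))
```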
```diff
@@ -392,26 +400,24 @@ def _encode(tokenizer, text, add_special_tokens):
             strs = [self.stop] if isinstance(self.stop, str) else self.stop
             self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]

-        # Add eos_token_id in generation_config to _stop_word_ids
-        # Refer to https://huggingface.co/docs/hub/en/transformers#transformers-repository-files and
-        # https://github.com/huggingface/transformers/blob/1ae4d917ed3badbdb1ffc167e0529f5a6d3c080d/src/transformers/generation/stopping_criteria.py#L451C1-L451C42
-        # The eos_token_id in generation_config are really mean to stop the text generation.
-        if generation_config is not None and generation_config.eos_token_id is not None:
-            if isinstance(generation_config.eos_token_id, int):
-                generation_eos_token_ids = [generation_config.eos_token_id]
-            else: # always List[int]
-                generation_eos_token_ids = generation_config.eos_token_id
-
-            if self._stop_word_ids is None:
-                self._stop_word_ids = [generation_eos_token_ids]
-            else:
+        # add generation_config to stop word list, only in qwen3-next now
+        if (
+            hf_model_config is not None
+            and hf_model_config.model_type == "qwen3_next"
+            and generation_config is not None
+            and isinstance(generation_config.eos_token_id, List)
+            and all(isinstance(i, int) for i in generation_config.eos_token_id)
+        ):
+            if self._stop_word_ids:
                 all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
-                from_generation_stop_token_ids = [
-                    i for i in generation_eos_token_ids if i not in all_stop_tokens_id
+                from_generation_stop_tokens = [
+                    i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
                 ]

-                if from_generation_stop_token_ids:
-                    self._stop_word_ids.append(from_generation_stop_token_ids)
+                if from_generation_stop_tokens:
+                    self._stop_word_ids.append(from_generation_stop_tokens)
+            else:
+                self._stop_word_ids = [generation_config.eos_token_id]

         return self
```
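
Similarly, a self-contained sketch of the qwen3_next stop-word merge in the second hunk: eos ids from the generation config are appended only if they are not already covered by existing stop words. `merge_generation_eos` is a hypothetical helper and the token ids are placeholders:

```python
# Hedged sketch of the qwen3_next stop-word merge above.
def merge_generation_eos(stop_word_ids, generation_eos_token_ids):
    """Append generation-config eos ids that are not already stop words."""
    if stop_word_ids:
        seen = set(i for sublist in stop_word_ids for i in sublist)
        extra = [i for i in generation_eos_token_ids if i not in seen]
        if extra:
            stop_word_ids.append(extra)
    else:
        # no stop words yet: the eos list becomes the only stop sequence
        stop_word_ids = [generation_eos_token_ids]
    return stop_word_ids

print(merge_generation_eos([[100, 101]], [101, 102]))  # [[100, 101], [102]]
print(merge_generation_eos(None, [101, 102]))          # [[101, 102]]
```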

tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -110,7 +110,7 @@ async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI,
         "content": "You are a helpful assistant."
     }, {
         "role": "user",
-        "content": "What is the capital of France? please in detail."
+        "content": "What is the capital of France?"
     }]
     # Test top_logprobs=1
     chat_completion = await async_client.chat.completions.create(
```

0 commit comments