
Commit 0582e54

[None][fix] modify qwen3-next sampling stop_tokens (NVIDIA#9331)
Signed-off-by: jiant <[email protected]>
1 parent 11a0b27 commit 0582e54

File tree

1 file changed: +19 -0 lines changed

tensorrt_llm/sampling_params.py

Lines changed: 19 additions & 0 deletions
@@ -395,6 +395,25 @@ def _encode(tokenizer, text, add_special_tokens):
             strs = [self.stop] if isinstance(self.stop, str) else self.stop
             self._stop_word_ids = [_encode(tokenizer, s, add_special_tokens) for s in strs]
 
+        # add generation_config to stop word list, only in qwen3-next now
+        if (
+            hf_model_config is not None
+            and hf_model_config.model_type == "qwen3_next"
+            and generation_config is not None
+            and isinstance(generation_config.eos_token_id, List)
+            and all(isinstance(i, int) for i in generation_config.eos_token_id)
+        ):
+            if self._stop_word_ids:
+                all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
+                from_generation_stop_tokens = [
+                    i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
+                ]
+
+                if from_generation_stop_tokens:
+                    self._stop_word_ids.append(from_generation_stop_tokens)
+            else:
+                self._stop_word_ids = [generation_config.eos_token_id]
+
         return self
 
     def _get_bad_words(self) -> List[List[int]]:
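
For clarity, below is a minimal, self-contained sketch of the merging behavior the added block implements. It uses SimpleNamespace stand-ins rather than the real SamplingParams, hf_model_config, or generation_config objects; the helper name merge_eos_into_stop_words and the example token ids are hypothetical and only illustrate the dedup-and-append logic.

# Minimal sketch of the stop-word merging added in this commit. SimpleNamespace objects
# stand in for the real hf_model_config / generation_config; the helper name and the
# token ids below are hypothetical and only illustrate the behavior.
from types import SimpleNamespace
from typing import List, Optional


def merge_eos_into_stop_words(
    stop_word_ids: Optional[List[List[int]]],
    hf_model_config,
    generation_config,
) -> Optional[List[List[int]]]:
    """Fold generation_config.eos_token_id into the stop word list for qwen3_next,
    skipping token ids that are already covered by existing stop words."""
    eligible = (
        hf_model_config is not None
        and hf_model_config.model_type == "qwen3_next"
        and generation_config is not None
        and isinstance(generation_config.eos_token_id, list)
        and all(isinstance(i, int) for i in generation_config.eos_token_id)
    )
    if not eligible:
        return stop_word_ids

    if stop_word_ids:
        # Collect every token id already present in the existing stop sequences.
        already_covered = {i for seq in stop_word_ids for i in seq}
        extra = [i for i in generation_config.eos_token_id if i not in already_covered]
        if extra:
            stop_word_ids = stop_word_ids + [extra]
        return stop_word_ids

    # No user-provided stop words: use the eos ids from generation_config directly.
    return [generation_config.eos_token_id]


# Hypothetical usage: the user's stop strings encoded to [[151645]] and the model's
# generation_config lists two eos ids, one of which is already covered.
cfg = SimpleNamespace(model_type="qwen3_next")
gen = SimpleNamespace(eos_token_id=[151645, 151643])
print(merge_eos_into_stop_words([[151645]], cfg, gen))  # -> [[151645], [151643]]
print(merge_eos_into_stop_words(None, cfg, gen))        # -> [[151645, 151643]]

Note that, mirroring the diff's append(from_generation_stop_tokens), all leftover eos ids are appended as a single stop sequence rather than as one sequence per token id.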
