@@ -373,14 +373,6 @@ def _setup(
373373 if self .end_id is None :
374374 self .end_id = tokenizer .eos_token_id
375375 self .pad_id = tokenizer .pad_token_id
376- # kimi_k2 model uses the eos_token_id in generation config
377- if (
378- hf_model_config is not None
379- and hf_model_config .model_type == "kimi_k2"
380- and generation_config is not None
381- and isinstance (generation_config .eos_token_id , int )
382- ):
383- self .end_id = generation_config .eos_token_id
384376
385377 if self .pad_id is None :
386378 self .pad_id = self .end_id
@@ -400,24 +392,26 @@ def _encode(tokenizer, text, add_special_tokens):
400392 strs = [self .stop ] if isinstance (self .stop , str ) else self .stop
401393 self ._stop_word_ids = [_encode (tokenizer , s , add_special_tokens ) for s in strs ]
402394
403- # add generation_config to stop word list, only in qwen3-next now
404- if (
405- hf_model_config is not None
406- and hf_model_config .model_type == "qwen3_next"
407- and generation_config is not None
408- and isinstance (generation_config .eos_token_id , List )
409- and all (isinstance (i , int ) for i in generation_config .eos_token_id )
410- ):
411- if self ._stop_word_ids :
395+ # Add eos_token_id in generation_config to _stop_word_ids
396+ # Refer to https://huggingface.co/docs/hub/en/transformers#transformers-repository-files and
397+ # https://github.com/huggingface/transformers/blob/1ae4d917ed3badbdb1ffc167e0529f5a6d3c080d/src/transformers/generation/stopping_criteria.py#L451C1-L451C42
398+ # The eos_token_id entries in generation_config are really meant to stop the text generation.
399+ if generation_config is not None and generation_config .eos_token_id is not None :
400+ if isinstance (generation_config .eos_token_id , int ):
401+ generation_eos_token_ids = [generation_config .eos_token_id ]
402+ else : # always List[int]
403+ generation_eos_token_ids = generation_config .eos_token_id
404+
405+ if self ._stop_word_ids is None :
406+ self ._stop_word_ids = [generation_eos_token_ids ]
407+ else :
412408 all_stop_tokens_id = set (i for sublist in self ._stop_word_ids for i in sublist )
413- from_generation_stop_tokens = [
414- i for i in generation_config . eos_token_id if i not in all_stop_tokens_id
409+ from_generation_stop_token_ids = [
410+ i for i in generation_eos_token_ids if i not in all_stop_tokens_id
415411 ]
416412
417- if from_generation_stop_tokens :
418- self ._stop_word_ids .append (from_generation_stop_tokens )
419- else :
420- self ._stop_word_ids = [generation_config .eos_token_id ]
413+ if from_generation_stop_token_ids :
414+ self ._stop_word_ids .append (from_generation_stop_token_ids )
421415
422416 return self
423417
0 commit comments