@@ -373,6 +373,14 @@ def _setup(
373373 if self .end_id is None :
374374 self .end_id = tokenizer .eos_token_id
375375 self .pad_id = tokenizer .pad_token_id
376+ # kimi_k2 model uses the eos_token_id in generation config
377+ if (
378+ hf_model_config is not None
379+ and hf_model_config .model_type == "kimi_k2"
380+ and generation_config is not None
381+ and isinstance (generation_config .eos_token_id , int )
382+ ):
383+ self .end_id = generation_config .eos_token_id
376384
377385 if self .pad_id is None :
378386 self .pad_id = self .end_id
@@ -392,26 +400,24 @@ def _encode(tokenizer, text, add_special_tokens):
392400 strs = [self .stop ] if isinstance (self .stop , str ) else self .stop
393401 self ._stop_word_ids = [_encode (tokenizer , s , add_special_tokens ) for s in strs ]
394402
395- # Add eos_token_id in generation_config to _stop_word_ids
396- # Refer to https://huggingface.co/docs/hub/en/transformers#transformers-repository-files and
397- # https://github.com/huggingface/transformers/blob/1ae4d917ed3badbdb1ffc167e0529f5a6d3c080d/src/transformers/generation/stopping_criteria.py#L451C1-L451C42
398- # The eos_token_id in generation_config are really mean to stop the text generation.
399- if generation_config is not None and generation_config .eos_token_id is not None :
400- if isinstance (generation_config .eos_token_id , int ):
401- generation_eos_token_ids = [generation_config .eos_token_id ]
402- else : # always List[int]
403- generation_eos_token_ids = generation_config .eos_token_id
404-
405- if self ._stop_word_ids is None :
406- self ._stop_word_ids = [generation_eos_token_ids ]
407- else :
403+ # add generation_config to stop word list, only in qwen3-next now
404+ if (
405+ hf_model_config is not None
406+ and hf_model_config .model_type == "qwen3_next"
407+ and generation_config is not None
408+ and isinstance (generation_config .eos_token_id , List )
409+ and all (isinstance (i , int ) for i in generation_config .eos_token_id )
410+ ):
411+ if self ._stop_word_ids :
408412 all_stop_tokens_id = set (i for sublist in self ._stop_word_ids for i in sublist )
409- from_generation_stop_token_ids = [
410- i for i in generation_eos_token_ids if i not in all_stop_tokens_id
413+ from_generation_stop_tokens = [
414+ i for i in generation_config . eos_token_id if i not in all_stop_tokens_id
411415 ]
412416
413- if from_generation_stop_token_ids :
414- self ._stop_word_ids .append (from_generation_stop_token_ids )
417+ if from_generation_stop_tokens :
418+ self ._stop_word_ids .append (from_generation_stop_tokens )
419+ else :
420+ self ._stop_word_ids = [generation_config .eos_token_id ]
415421
416422 return self
417423
0 commit comments