@@ -313,7 +313,7 @@ def test_bluelm_template(self):
         inputs = inputs.to('cuda:0')
         pred = model.generate(
             **inputs, max_new_tokens=64, repetition_penalty=1.1)
-        print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
+        response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
         print(f'official response: {response}')
         #
         input_ids_official = inputs['input_ids'][0].tolist()
@@ -592,7 +592,7 @@ def test_deepseek_template(self):
         'To avoid excessive testing time caused by downloading models and '
         'to prevent OOM (Out of Memory) errors.')
     def test_deepseek_coder_template(self):
-        model_type = ModelType.deepseek_coder_6_7b_chat
+        model_type = ModelType.deepseek_coder_6_7b_instruct
         model, tokenizer = get_model_tokenizer(model_type)
         template_type = get_default_template_type(model_type)
         template = get_template(template_type, tokenizer)
@@ -620,7 +620,8 @@ def test_deepseek_coder_template(self):
         input_ids_official = tokenizer.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True)
         inputs = torch.tensor(input_ids_official, device='cuda')[None]
-        outputs = model.generate(input_ids=inputs)
+        outputs = model.generate(
+            input_ids=inputs, eos_token_id=tokenizer.eos_token_id)
         response = tokenizer.decode(
             outputs[0, len(inputs[0]):], skip_special_tokens=True)
         print(f'official response: {response}')