@@ -201,13 +201,14 @@ def display(self):
         if self.console_line is not None:
             print(term.move_xy(0, self.console_line) + self.display_text)
 
-def get_stop_conditions(prompt_format, tokenizer):
-    if prompt_format == "llama":
+def get_stop_conditions(tokenizer):
+    # Special case when the model is llama3 (detected via repo_str)
+    if "llama3" in repo_str:
+        return [tokenizer.single_id("<|eot_id|>"), tokenizer.eos_token_id]
+    # elif prompt_format == "granite":
+    #     return [tokenizer.eos_token_id, "\n\nQuestion:"]
+    else:
         return [tokenizer.eos_token_id]
-    elif prompt_format == "llama3":
-        return [tokenizer.single_id("<|eot_id|>")]
-    elif prompt_format == "granite":
-        return [tokenizer.eos_token_id, "\n\nQuestion:"]
 
 config = configparser.ConfigParser()
 config.read('config.ini')
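
A quick usage sketch of the new signature, assuming the ExLlamaV2 tokenizer and the module-level repo_str this script reads from config.ini (the repo string below is a hypothetical value):

    # Hypothetical value; repo_str is normally read from config.ini.
    repo_str = "llama3-8b-instruct-exl2"
    stops = get_stop_conditions(tokenizer)
    # llama3 repos  -> [tokenizer.single_id("<|eot_id|>"), tokenizer.eos_token_id]
    # anything else -> [tokenizer.eos_token_id]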
@@ -466,11 +467,19 @@ def process_prompts():
             #streamer.append(stream)
             #prompt_ids.append(prompt_id)
 
+            preferred_eos = get_stop_conditions(tokenizer)
+
+            if stop_at is not None:
+                preferred_eos.append(stop_at)
+
+            gen_settings = ExLlamaV2Sampler.Settings()
+            gen_settings.temperature = 1.0 if temperature > 1 else temperature  # cap temperature at 1.0
+
             job = ExLlamaV2DynamicJob(
                 input_ids = ids,
                 max_new_tokens = max_tokens,
-                stop_conditions = [tokenizer.eos_token_id] if stop_at is None else [tokenizer.eos_token_id, stop_at],
-                gen_settings = ExLlamaV2Sampler.Settings(),
+                stop_conditions = preferred_eos,
+                gen_settings = gen_settings,
                 filters = filters,
                 token_healing = healing
             )
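
For reference, a minimal standalone sketch of the stop-condition and temperature plumbing introduced above, assuming the tokenizer and get_stop_conditions from this file; min() here has the same effect as the ternary in the diff:

    from exllamav2.generator import ExLlamaV2Sampler

    temperature = 1.2   # example request value
    stop_at = "\n\n"    # example optional stop; stop_conditions may mix strings and token ids

    preferred_eos = get_stop_conditions(tokenizer)
    if stop_at is not None:
        preferred_eos.append(stop_at)

    gen_settings = ExLlamaV2Sampler.Settings()
    gen_settings.temperature = min(temperature, 1.0)  # cap at 1.0, equivalent to the ternary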