Skip to content

Commit c4b4c8f

Browse files
committed
Pass temperature with ExLlamaV2Sampler.Settings() and also made the EOS compatible with Llama3 and 'stop_at' string from outlines
1 parent 3855734 commit c4b4c8f

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

llm_exl2_dynamic_gen.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,14 @@ def display(self):
201201
if self.console_line is not None:
202202
print(term.move_xy(0, self.console_line) + self.display_text)
203203

204-
def get_stop_conditions(prompt_format, tokenizer):
205-
if prompt_format == "llama":
204+
def get_stop_conditions(tokenizer):
205+
# get_stop_condition special case if model is llama3
206+
if "llama3" in repo_str:
207+
return [tokenizer.single_id("<|eot_id|>"), tokenizer.eos_token_id]
208+
# elif prompt_format == "granite":
209+
# return [tokenizer.eos_token_id, "\n\nQuestion:"]
210+
else:
206211
return [tokenizer.eos_token_id]
207-
elif prompt_format == "llama3":
208-
return [tokenizer.single_id("<|eot_id|>")]
209-
elif prompt_format == "granite":
210-
return [tokenizer.eos_token_id, "\n\nQuestion:"]
211212

212213
config = configparser.ConfigParser()
213214
config.read('config.ini')
@@ -466,11 +467,19 @@ def process_prompts():
466467
#streamer.append(stream)
467468
#prompt_ids.append(prompt_id)
468469

470+
preferred_eos = get_stop_conditions(tokenizer)
471+
472+
if stop_at is not None:
473+
preferred_eos.append(stop_at)
474+
475+
gen_settings = ExLlamaV2Sampler.Settings()
476+
gen_settings.temperature = 1.0 if temperature>1 else temperature # To make sure the temperature value does not exceed 1
477+
469478
job = ExLlamaV2DynamicJob(
470479
input_ids = ids,
471480
max_new_tokens = max_tokens,
472-
stop_conditions = [tokenizer.eos_token_id] if stop_at is None else [tokenizer.eos_token_id, stop_at],
473-
gen_settings = ExLlamaV2Sampler.Settings(),
481+
stop_conditions = preferred_eos if stop_at is None else [tokenizer.eos_token_id, stop_at],
482+
gen_settings = gen_settings,
474483
filters = filters,
475484
token_healing = healing
476485
)

0 commit comments

Comments
 (0)