Skip to content

Commit cfd192d

Browse files
Update generative-proof-of-concept-CPU-preprocessing-in-memory.py
Fix duplicate parameters `result` and `result_0` by consolidating them into `result_0`.
1 parent c659a7b commit cfd192d

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

generative-proof-of-concept-CPU-preprocessing-in-memory.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,19 +1089,19 @@ def complete_text_beam(text: str,
10891089
# print(f"I ask the generator (Beam defaults - max_new_tokens: 10, temperature: 0.75, top_k: 75, top_p: 0.98, repetition_penalty: None, presence_penalty: 1.3, frequency_penalty: 1.4): {test_text_block}... It responds: '{response}'.")
10901090

10911091
trial_number = int(trial.number)
1092-
def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result: float, result_cutoff: float, trial_id: int, test_sample_number: int, result_0: float) -> None:
1092+
def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result_cutoff: float, trial_id: int, test_sample_number: int, result_0: float) -> None:
10931093
"""
1094-
If the result < result_cutoff, this will run a matrix of different sampling values and print out the resulting text for human subjective evaluation.
1094+
If the result_0 < result_cutoff, this will run a matrix of different sampling values and print out the resulting text for human subjective evaluation.
10951095
10961096
Parameters:
10971097
- test_prompt: a string to prompt generation
10981098
- max_new_tokens: int, number of tokens to generate unless we generate a stop token.
10991099
- sample_number: Metadata for sample...
1100-
- result: Perplexity score from this run
1100+
- result_0: Perplexity score from this run
11011101
- result_cutoff: Perplexity score that would be expected to indicate a trial worth running this on
11021102
11031103
"""
1104-
if result < result_cutoff:
1104+
if result_0 < result_cutoff:
11051105
generation_param_permutations = [
11061106
# #3
11071107
{
@@ -1258,11 +1258,10 @@ def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result:
12581258
test_prompt=sample,
12591259
max_new_tokens=MAX_NEW_TOKENS,
12601260
sample_number=counter,
1261-
result=phase_i_a_result,
12621261
result_cutoff=RESULT_CUTOFF,
12631262
trial_id=trial_number,
12641263
test_sample_number=counter,
1265-
result_0=result)
1264+
result_0=phase_i_a_result)
12661265
counter += 1
12671266

12681267
# # Tokenize the text without padding first to get actual tokens

0 commit comments

Comments
 (0)