Skip to content

Commit d8ae574

Browse files
author
Lazaro Hurtado
committed
small optimizations
1 parent f2eb3de commit d8ae574

File tree

7 files changed

+126
-72
lines changed

7 files changed

+126
-72
lines changed

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,9 @@ create_venv:
1010
clean:
1111
find . -type d -name "__pycache__" -exec rm -rf {} +
1212

13-
destroy: clean
13+
reset_run:
14+
find . -type d -name "results" -exec rm -rf {} +
15+
find . -type d -name "contexts" -exec rm -rf {} +
16+
17+
destroy: clean reset_run
1418
rm -rf ./$(VENV_NAME)

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ $ pip install -r requirements.txt
1919
You can then run the analysis on OpenAI or Anthropic models by running `main.py` with the command line arguments shown below. `LLMNeedleHaystackTester` parameters can also be passed as command line arguments, except `model_to_test` and `evaluator` of course.
2020
* `provider` - The provider of the model, available options are `openai` and `anthropic`. Defaults to `openai`
2121
* `evaluator` - The provider for the evaluator model, only `openai` is currently supported. Defaults to `openai`.
22+
* `model_name` - Model name of the language model accessible by the provider. Defaults to `gpt-3.5-turbo-0125`
23+
* `evaluator_model_name` - Model name of the language model accessible by the evaluator. Defaults to `gpt-3.5-turbo-0125`
2224
* `api_key` - API key for either OpenAI or Anthropic provider. Can either be passed as a command line argument or an environment variable named `OPENAI_API_KEY` or `ANTHROPIC_API_KEY` depending on the provider. Defaults to `None`.
2325
* `evaluator_api_key` - API key for OpenAI provider. Can either be passed as a command line argument or an environment variable named `OPENAI_API_KEY`. Defaults to `None`
2426

main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
class CommandArgs():
1414
provider: str = "openai"
1515
evaluator: str = "openai"
16+
model_name: Optional[str] = "gpt-3.5-turbo-0125"
17+
evaluator_model_name: Optional[str] = "gpt-3.5-turbo-0125"
1618
api_key: Optional[str] = None
1719
evaluator_api_key: Optional[str] = None
1820
needle: Optional[str] = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
@@ -38,16 +40,17 @@ class CommandArgs():
3840
def get_model_to_test(args: CommandArgs) -> ModelProvider:
3941
match args.provider.lower():
4042
case "openai":
41-
return OpenAI(api_key=args.api_key)
43+
return OpenAI(model_name=args.model_name, api_key=args.api_key)
4244
case "anthropic":
43-
return Anthropic(api_key=args.api_key)
45+
return Anthropic(model_name=args.model_name, api_key=args.api_key)
4446
case _:
4547
raise ValueError(f"Invalid provider: {args.provider}")
4648

4749
def get_evaluator(args: CommandArgs) -> Evaluator:
4850
match args.evaluator.lower():
4951
case "openai":
50-
return OpenAIEvaluator(question_asked=args.retrieval_question,
52+
return OpenAIEvaluator(model_name=args.evaluator_model_name,
53+
question_asked=args.retrieval_question,
5154
true_answer=args.needle,
5255
api_key=args.evaluator_api_key)
5356
case _:

src/evaluators/openai_evaluator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from langchain_community.chat_models import ChatOpenAI
88

99
class OpenAIEvaluator(Evaluator):
10+
DEFAULT_MODEL_KWARGS: dict = dict(temperature=0)
1011
CRITERIA = {"accuracy": """
1112
Score 1: The answer is completely unrelated to the reference.
1213
Score 3: The answer has minor relevance but does not align with the reference.
@@ -17,11 +18,13 @@ class OpenAIEvaluator(Evaluator):
1718

1819
def __init__(self,
1920
model_name: str = "gpt-3.5-turbo-0125",
21+
model_kwargs: dict = DEFAULT_MODEL_KWARGS,
2022
api_key: str = None,
2123
true_answer: str = None,
22-
question_asked: str = None):
24+
question_asked: str = None,):
2325
"""
2426
:param model_name: The name of the model.
27+
:param model_kwargs: Model configuration. Default is {temperature: 0}
2528
:param api_key: The API key for OpenAI. Default is None.
2629
:param true_answer: The true answer to the question asked.
2730
:param question_asked: The question asked to the model.
@@ -31,6 +34,7 @@ def __init__(self,
3134
raise ValueError("true_answer and question_asked must be supplied with init.")
3235

3336
self.model_name = model_name
37+
self.model_kwargs = model_kwargs
3438
self.true_answer = true_answer
3539
self.question_asked = question_asked
3640

@@ -40,8 +44,8 @@ def __init__(self,
4044
self.api_key = api_key or os.getenv('OPENAI_API_KEY')
4145

4246
self.evaluator = ChatOpenAI(model=self.model_name,
43-
temperature=0,
44-
openai_api_key=self.api_key)
47+
openai_api_key=self.api_key,
48+
**self.model_kwargs)
4549

4650
def evaluate_response(self, response: str) -> int:
4751
evaluator = load_evaluator(

src/llm_needle_haystack_tester.py

Lines changed: 86 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def __init__(self,
6464
"""
6565
if not model_to_test:
6666
raise ValueError("A language model must be provided to test.")
67+
if not evaluator:
68+
raise ValueError("An evaluator must be provided to evaluate the model's response.")
6769
if not needle or not haystack_dir or not retrieval_question:
6870
raise ValueError("Needle, haystack, and retrieval_question must be provided.")
6971

@@ -77,13 +79,20 @@ def __init__(self,
7779
self.save_contexts = save_contexts
7880
self.seconds_to_sleep_between_completions = seconds_to_sleep_between_completions
7981
self.print_ongoing_status = print_ongoing_status
82+
83+
self.context_dir = 'contexts'
84+
self.results_dir = 'results'
85+
self.result_file_format = '{model_name}_len_{context_length}_depth_{depth_percent}'
8086
self.testing_results = []
8187

8288
if context_lengths is None:
8389
if context_lengths_min is None or context_lengths_max is None or context_lengths_num_intervals is None:
8490
raise ValueError("Either context_lengths_min, context_lengths_max, context_lengths_intervals need to be filled out OR the context_lengths_list needs to be supplied.")
8591
else:
86-
self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
92+
self.context_lengths = self.get_intervals(context_lengths_min,
93+
context_lengths_max,
94+
context_lengths_num_intervals,
95+
"linear")
8796
else:
8897
self.context_lengths = context_lengths
8998

@@ -93,13 +102,13 @@ def __init__(self,
93102
if document_depth_percents is None:
94103
if document_depth_percent_min is None or document_depth_percent_max is None or document_depth_percent_intervals is None:
95104
raise ValueError("Either document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals need to be filled out OR the document_depth_percents needs to be supplied.")
96-
97-
if document_depth_percent_interval_type == 'linear':
98-
self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
99-
elif document_depth_percent_interval_type == 'sigmoid':
100-
self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
101-
else:
105+
if document_depth_percent_interval_type not in ['linear', 'sigmoid']:
102106
raise ValueError("document_depth_percent_interval_type must be either 'sigmoid' or 'linear' if document_depth_percents is None.")
107+
108+
self.document_depth_percents = self.get_intervals(document_depth_percent_min,
109+
document_depth_percent_max,
110+
document_depth_percent_intervals,
111+
document_depth_percent_interval_type)
103112
else:
104113
self.document_depth_percents = document_depth_percents
105114

@@ -108,6 +117,17 @@ def __init__(self,
108117

109118
self.evaluation_model = evaluator
110119

120+
def get_intervals(self, min_depth, max_depth, num_intervals, interval_type):
121+
linear_spacing = np.linspace(min_depth, max_depth, num=num_intervals, endpoint=True)
122+
123+
match interval_type:
124+
case 'linear':
125+
return np.round(linear_spacing).astype(int)
126+
case 'sigmoid':
127+
return [self.logistic(x) for x in linear_spacing]
128+
case _:
129+
return []
130+
111131
def logistic(self, x, L=100, x0=50, k=.1):
112132
if x in [0, 100]:
113133
return x
@@ -122,24 +142,20 @@ async def bound_evaluate_and_log(self, sem, *args):
122142
await self.evaluate_and_log(*args)
123143

124144
async def run_test(self):
125-
sem = Semaphore(self.num_concurrent_requests)
126-
127-
# Run through each iteration of context_lengths and depths
128-
tasks = []
129-
for context_length in self.context_lengths:
130-
for depth_percent in self.document_depth_percents:
131-
task = self.bound_evaluate_and_log(sem, context_length, depth_percent)
132-
tasks.append(task)
145+
async with asyncio.TaskGroup() as tg:
146+
sem = Semaphore(self.num_concurrent_requests)
133147

134-
# Wait for all tasks to complete
135-
await asyncio.gather(*tasks)
148+
# Run through each iteration of context_lengths and depths
149+
for context_length in self.context_lengths:
150+
for depth_percent in self.document_depth_percents:
151+
task = self.bound_evaluate_and_log(sem, context_length, depth_percent)
152+
tg.create_task(task)
136153

137154
async def evaluate_and_log(self, context_length, depth_percent):
138155
# Checks to see if you've already checked a length/percent/version.
139156
# This helps if the program stop running and you want to restart later
140-
if self.save_results:
141-
if self.result_exists(context_length, depth_percent):
142-
return
157+
if self.save_results and self.result_exists(context_length, depth_percent):
158+
return
143159

144160
# Go generate the required length context and place your needle statement in
145161
context = await self.generate_context(context_length, depth_percent)
@@ -159,47 +175,45 @@ async def evaluate_and_log(self, context_length, depth_percent):
159175
score = self.evaluation_model.evaluate_response(response)
160176

161177
results = {
162-
# 'context' : context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
163-
'model' : self.model_name,
164-
'context_length' : int(context_length),
165-
'depth_percent' : float(depth_percent),
166-
'version' : self.results_version,
167-
'needle' : self.needle,
168-
'model_response' : response,
169-
'score' : score,
170-
'test_duration_seconds' : test_elapsed_time,
171-
'test_timestamp_utc' : datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
178+
# 'context': context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
179+
'model': self.model_name,
180+
'context_length': int(context_length),
181+
'depth_percent': float(depth_percent),
182+
'version': self.results_version,
183+
'needle': self.needle,
184+
'model_response': response,
185+
'score': score,
186+
'test_duration_seconds': test_elapsed_time,
187+
'test_timestamp_utc': datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
172188
}
173189

174190
self.testing_results.append(results)
175191

176192
if self.print_ongoing_status:
177-
print (f"-- Test Summary -- ")
178-
print (f"Duration: {test_elapsed_time:.1f} seconds")
179-
print (f"Context: {context_length} tokens")
180-
print (f"Depth: {depth_percent}%")
181-
print (f"Score: {score}")
182-
print (f"Response: {response}\n")
193+
self.print_status(test_elapsed_time, context_length, depth_percent, score, response)
183194

184-
context_file_location = f'{self.model_name.replace(".", "_")}_len_{context_length}_depth_{int(depth_percent*100)}'
195+
parsed_model_name = self.model_name.replace(".", "_")
196+
context_file_location = self.result_file_format.format(model_name=parsed_model_name,
197+
context_length=context_length,
198+
depth_percent=int(depth_percent))
185199

186200
if self.save_contexts:
187201
results['file_name'] = context_file_location
188202

189203
# Save the context to file for retesting
190-
if not os.path.exists('contexts'):
191-
os.makedirs('contexts')
204+
if not os.path.exists(self.context_dir):
205+
os.makedirs(self.context_dir)
192206

193-
with open(f'contexts/{context_file_location}_context.txt', 'w') as f:
207+
with open(f'{self.context_dir}/{context_file_location}_context.txt', 'w') as f:
194208
f.write(context)
195209

196210
if self.save_results:
197211
# Save the context to file for retesting
198-
if not os.path.exists('results'):
199-
os.makedirs('results')
212+
if not os.path.exists(self.results_dir):
213+
os.makedirs(self.results_dir)
200214

201215
# Save the result to file for retesting
202-
with open(f'results/{context_file_location}_results.json', 'w') as f:
216+
with open(f'{self.results_dir}/{context_file_location}_results.json', 'w') as f:
203217
json.dump(results, f)
204218

205219
if self.seconds_to_sleep_between_completions:
@@ -210,20 +224,21 @@ def result_exists(self, context_length, depth_percent):
210224
Checks to see if a result has already been evaluated or not
211225
"""
212226

213-
results_dir = 'results/'
214-
if not os.path.exists(results_dir):
227+
if not os.path.exists(self.results_dir):
228+
return False
229+
230+
filename = self.result_file_format.format(model_name=self.model_name,
231+
context_length=context_length,
232+
depth_percent=depth_percent)
233+
file_path = os.path.join(self.results_dir, f'{filename}.json')
234+
if not os.path.exists(file_path):
215235
return False
216236

217-
for filename in os.listdir(results_dir):
218-
if filename.endswith('.json'):
219-
with open(os.path.join(results_dir, filename), 'r') as f:
220-
result = json.load(f)
221-
context_length_met = result['context_length'] == context_length
222-
depth_percent_met = result['depth_percent'] == depth_percent
223-
version_met = result.get('version', 1) == self.results_version
224-
model_met = result['model'] == self.model_name
225-
if context_length_met and depth_percent_met and version_met and model_met:
226-
return True
237+
with open(file_path, 'r') as f:
238+
result = json.load(f)
239+
240+
if result.get('version', 1) == self.results_version:
241+
return True
227242
return False
228243

229244
async def generate_context(self, context_length, depth_percent):
@@ -265,9 +280,9 @@ def insert_needle(self, context, depth_percent, context_length):
265280
period_tokens = self.model_to_test.encode_text_to_tokens('.')
266281

267282
# Then we iteration backwards until we find the first period
268-
while tokens_new_context and tokens_new_context[-1] not in period_tokens:
283+
while (insertion_point > 0) and (tokens_new_context[insertion_point-1] != period_tokens):
269284
insertion_point -= 1
270-
tokens_new_context = tokens_context[:insertion_point]
285+
tokens_new_context = tokens_context[:insertion_point]
271286

272287
# Once we get there, then add in your needle, and stick the rest of your context in on the other end.
273288
# Now we have a needle in a haystack
@@ -282,12 +297,16 @@ def get_context_length_in_tokens(self, context):
282297

283298
def read_context_files(self):
284299
context = ""
300+
current_context_length = 0
285301
max_context_length = max(self.context_lengths)
286302

287-
while self.get_context_length_in_tokens(context) < max_context_length:
303+
while current_context_length < max_context_length:
288304
for file in glob.glob(f"{self.haystack_dir}/*.txt"):
289305
with open(file, 'r') as f:
290-
context += f.read()
306+
file_content = f.read()
307+
308+
context += file_content
309+
current_context_length += self.get_context_length_in_tokens(file_content)
291310
return context
292311

293312
def encode_and_trim(self, context, context_length):
@@ -308,6 +327,14 @@ def print_start_test_summary(self):
308327
print (f"- Needle: {self.needle.strip()}")
309328
print ("\n\n")
310329

330+
def print_status(self, elapsed_time, context_length, depth_percent, score, response):
331+
print (f"-- Test Summary -- ")
332+
print (f"Duration: {elapsed_time:.1f} seconds")
333+
print (f"Context: {context_length} tokens")
334+
print (f"Depth: {depth_percent}%")
335+
print (f"Score: {score}")
336+
print (f"Response: {response}\n")
337+
311338
def start_test(self):
312339
if self.print_ongoing_status:
313340
self.print_start_test_summary()

src/providers/anthropic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,16 @@
66
from typing import Optional
77

88
class Anthropic(ModelProvider):
9-
def __init__(self, model_name: str = "claude", api_key: str = None):
9+
DEFAULT_MODEL_KWARGS: dict = dict(max_tokens_to_sample = 300,
10+
temperature = 0)
11+
12+
def __init__(self,
13+
model_name: str = "claude",
14+
model_kwargs: dict = DEFAULT_MODEL_KWARGS,
15+
api_key: str = None):
1016
"""
1117
:param model_name: The name of the model. Default is 'claude'.
18+
:param model_kwargs: Model configuration. Default is {max_tokens_to_sample: 300, temperature: 0}
1219
:param api_key: The API key for Anthropic. Default is None.
1320
"""
1421

@@ -19,6 +26,7 @@ def __init__(self, model_name: str = "claude", api_key: str = None):
1926
raise ValueError("Either api_key must be supplied with init, or ANTHROPIC_API_KEY must be in env.")
2027

2128
self.model_name = model_name
29+
self.model_kwargs = model_kwargs
2230
self.api_key = api_key or os.getenv('ANTHROPIC_API_KEY')
2331

2432
self.model = AsyncAnthropic(api_key=self.api_key)
@@ -32,9 +40,8 @@ def __init__(self, model_name: str = "claude", api_key: str = None):
3240
async def evaluate_model(self, prompt: str) -> str:
3341
response = await self.model.completions.create(
3442
model=self.model_name,
35-
max_tokens_to_sample=300,
3643
prompt=prompt,
37-
temperature=0)
44+
**self.model_kwargs)
3845
return response.completion
3946

4047
def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:

0 commit comments

Comments
 (0)