Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 88 additions & 59 deletions needlehaystack/llm_needle_haystack_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class LLMNeedleHaystackTester:
"""
This class is used to test the LLM Needle Haystack.
"""

CONTEXT_DIR = 'contexts'
RESULTS_DIR = 'results'
RESULT_FILE_FORMAT = '{model_name}_len_{context_length}_depth_{depth_percent}'

def __init__(self,
model_to_test: ModelProvider = None,
evaluator: Evaluator = None,
Expand Down Expand Up @@ -65,6 +70,8 @@ def __init__(self,
"""
if not model_to_test:
raise ValueError("A language model must be provided to test.")
if not evaluator:
raise ValueError("An evaluator must be provided to evaluate the model's response.")
if not needle or not haystack_dir or not retrieval_question:
raise ValueError("Needle, haystack, and retrieval_question must be provided.")

Expand All @@ -78,13 +85,17 @@ def __init__(self,
self.save_contexts = save_contexts
self.seconds_to_sleep_between_completions = seconds_to_sleep_between_completions
self.print_ongoing_status = print_ongoing_status

self.testing_results = []

if context_lengths is None:
if context_lengths_min is None or context_lengths_max is None or context_lengths_num_intervals is None:
raise ValueError("Either context_lengths_min, context_lengths_max, context_lengths_intervals need to be filled out OR the context_lengths_list needs to be supplied.")
else:
self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
self.context_lengths = self.get_intervals(context_lengths_min,
context_lengths_max,
context_lengths_num_intervals,
"linear")
else:
self.context_lengths = context_lengths

Expand All @@ -94,13 +105,13 @@ def __init__(self,
if document_depth_percents is None:
if document_depth_percent_min is None or document_depth_percent_max is None or document_depth_percent_intervals is None:
raise ValueError("Either document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals need to be filled out OR the document_depth_percents needs to be supplied.")

if document_depth_percent_interval_type == 'linear':
self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
elif document_depth_percent_interval_type == 'sigmoid':
self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
else:
if document_depth_percent_interval_type not in ['linear', 'sigmoid']:
raise ValueError("document_depth_percent_interval_type must be either 'sigmoid' or 'linear' if document_depth_percents is None.")

self.document_depth_percents = self.get_intervals(document_depth_percent_min,
document_depth_percent_max,
document_depth_percent_intervals,
document_depth_percent_interval_type)
else:
self.document_depth_percents = document_depth_percents

Expand All @@ -109,6 +120,17 @@ def __init__(self,

self.evaluation_model = evaluator

def get_intervals(self, min_depth, max_depth, num_intervals, interval_type):
linear_spacing = np.linspace(min_depth, max_depth, num=num_intervals, endpoint=True)

match interval_type:
case 'linear':
return np.round(linear_spacing).astype(int)
case 'sigmoid':
return [self.logistic(x) for x in linear_spacing]
case _:
return []

def logistic(self, x, L=100, x0=50, k=.1):
if x in [0, 100]:
return x
Expand All @@ -123,24 +145,20 @@ async def bound_evaluate_and_log(self, sem, *args):
await self.evaluate_and_log(*args)

async def run_test(self):
    """Run the needle-in-a-haystack test over every (context_length, depth_percent) pair.

    Uses asyncio.TaskGroup (Python 3.11+) so that a failure in any task
    cancels the remaining ones; actual concurrency is capped by a
    Semaphore of num_concurrent_requests, acquired inside
    bound_evaluate_and_log.
    """
    async with asyncio.TaskGroup() as tg:
        sem = Semaphore(self.num_concurrent_requests)

        # One task per (context_length, depth_percent) combination.
        for context_length in self.context_lengths:
            for depth_percent in self.document_depth_percents:
                tg.create_task(self.bound_evaluate_and_log(sem, context_length, depth_percent))

async def evaluate_and_log(self, context_length, depth_percent):
# Checks to see if you've already checked a length/percent/version.
# This helps if the program stop running and you want to restart later
if self.save_results:
if self.result_exists(context_length, depth_percent):
return
if self.save_results and self.result_exists(context_length, depth_percent):
return

# Go generate the required length context and place your needle statement in
context = await self.generate_context(context_length, depth_percent)
Expand All @@ -160,47 +178,45 @@ async def evaluate_and_log(self, context_length, depth_percent):
score = self.evaluation_model.evaluate_response(response)

results = {
# 'context' : context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
'model' : self.model_name,
'context_length' : int(context_length),
'depth_percent' : float(depth_percent),
'version' : self.results_version,
'needle' : self.needle,
'model_response' : response,
'score' : score,
'test_duration_seconds' : test_elapsed_time,
'test_timestamp_utc' : datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
# 'context': context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
'model': self.model_name,
'context_length': int(context_length),
'depth_percent': float(depth_percent),
'version': self.results_version,
'needle': self.needle,
'model_response': response,
'score': score,
'test_duration_seconds': test_elapsed_time,
'test_timestamp_utc': datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
}

self.testing_results.append(results)

if self.print_ongoing_status:
print (f"-- Test Summary -- ")
print (f"Duration: {test_elapsed_time:.1f} seconds")
print (f"Context: {context_length} tokens")
print (f"Depth: {depth_percent}%")
print (f"Score: {score}")
print (f"Response: {response}\n")
self.print_status(test_elapsed_time, context_length, depth_percent, score, response)

context_file_location = f'{self.model_name.replace(".", "_")}_len_{context_length}_depth_{int(depth_percent*100)}'
parsed_model_name = self.model_name.replace(".", "_")
context_file_location = self.RESULT_FILE_FORMAT.format(model_name=parsed_model_name,
context_length=context_length,
depth_percent=int(depth_percent))

if self.save_contexts:
results['file_name'] = context_file_location

# Save the context to file for retesting
if not os.path.exists('contexts'):
os.makedirs('contexts')
if not os.path.exists(self.CONTEXT_DIR):
os.makedirs(self.CONTEXT_DIR)

with open(f'contexts/{context_file_location}_context.txt', 'w') as f:
with open(f'{self.CONTEXT_DIR}/{context_file_location}_context.txt', 'w') as f:
f.write(context)

if self.save_results:
# Save the context to file for retesting
if not os.path.exists('results'):
os.makedirs('results')
if not os.path.exists(self.RESULTS_DIR):
os.makedirs(self.RESULTS_DIR)

# Save the result to file for retesting
with open(f'results/{context_file_location}_results.json', 'w') as f:
with open(f'{self.RESULTS_DIR}/{context_file_location}_results.json', 'w') as f:
json.dump(results, f)

if self.seconds_to_sleep_between_completions:
def result_exists(self, context_length, depth_percent):
    """
    Checks to see if a result has already been evaluated or not.

    Returns True iff a saved result file for this model / context_length /
    depth_percent exists under RESULTS_DIR and its recorded 'version'
    matches self.results_version. Used to make interrupted runs resumable.
    """
    if not os.path.exists(self.RESULTS_DIR):
        return False

    # Mirror the naming used when results are written: dots in the model
    # name are replaced with underscores and depth_percent is truncated to
    # an int — otherwise the lookup never matches the saved file.
    filename = self.RESULT_FILE_FORMAT.format(model_name=self.model_name.replace(".", "_"),
                                              context_length=context_length,
                                              depth_percent=int(depth_percent))
    file_path = os.path.join(self.RESULTS_DIR, f'{filename}_results.json')
    if not os.path.exists(file_path):
        return False

    with open(file_path, 'r') as f:
        result = json.load(f)

    return result.get('version', 1) == self.results_version

async def generate_context(self, context_length, depth_percent):
Expand Down Expand Up @@ -266,9 +283,9 @@ def insert_needle(self, context, depth_percent, context_length):
period_tokens = self.model_to_test.encode_text_to_tokens('.')

# Then we iteration backwards until we find the first period
while tokens_new_context and tokens_new_context[-1] not in period_tokens:
while (insertion_point > 0) and (tokens_new_context[insertion_point-1] != period_tokens):
insertion_point -= 1
tokens_new_context = tokens_context[:insertion_point]
tokens_new_context = tokens_context[:insertion_point]

# Once we get there, then add in your needle, and stick the rest of your context in on the other end.
# Now we have a needle in a haystack
Expand All @@ -283,13 +300,17 @@ def get_context_length_in_tokens(self, context):

def read_context_files(self):
    """Concatenate haystack .txt files until the largest requested context length is reached.

    Files are read from haystack_dir (resolved relative to this package)
    and appended repeatedly. A running token count is kept so the whole
    accumulated context is not re-tokenised on every pass.

    Returns:
        The concatenated haystack text.

    Raises:
        ValueError: if a full pass adds no tokens (empty or missing
            haystack directory), which would otherwise loop forever.
    """
    context = ""
    current_context_length = 0
    max_context_length = max(self.context_lengths)
    base_dir = os.path.abspath(os.path.dirname(__file__))  # Package directory

    while current_context_length < max_context_length:
        length_before_pass = current_context_length
        for file in glob.glob(os.path.join(base_dir, self.haystack_dir, "*.txt")):
            with open(file, 'r') as f:
                file_content = f.read()

            context += file_content
            current_context_length += self.get_context_length_in_tokens(file_content)

        # Guard: an empty haystack directory would otherwise spin forever.
        if current_context_length == length_before_pass:
            raise ValueError(f"No haystack content found in {self.haystack_dir}; cannot build context.")
    return context

def encode_and_trim(self, context, context_length):
Expand All @@ -310,6 +331,14 @@ def print_start_test_summary(self):
print (f"- Needle: {self.needle.strip()}")
print ("\n\n")

def print_status(self, elapsed_time, context_length, depth_percent, score, response):
    """Print a short human-readable summary of a single test run.

    Args:
        elapsed_time: wall-clock duration of the test in seconds.
        context_length: context size (in tokens) used for this run.
        depth_percent: needle depth as a percentage of the context.
        score: evaluator score for the model's response.
        response: raw model response text.
    """
    print("-- Test Summary -- ")
    print(f"Duration: {elapsed_time:.1f} seconds")
    print(f"Context: {context_length} tokens")
    print(f"Depth: {depth_percent}%")
    print(f"Score: {score}")
    print(f"Response: {response}\n")

def start_test(self):
if self.print_ongoing_status:
self.print_start_test_summary()
Expand Down