@@ -64,6 +64,8 @@ def __init__(self,
6464 """
6565 if not model_to_test :
6666 raise ValueError ("A language model must be provided to test." )
67+ if not evaluator :
68+ raise ValueError ("An evaluator must be provided to evaluate the model's response." )
6769 if not needle or not haystack_dir or not retrieval_question :
6870 raise ValueError ("Needle, haystack, and retrieval_question must be provided." )
6971
@@ -77,13 +79,20 @@ def __init__(self,
7779 self .save_contexts = save_contexts
7880 self .seconds_to_sleep_between_completions = seconds_to_sleep_between_completions
7981 self .print_ongoing_status = print_ongoing_status
82+
83+ self .context_dir = 'contexts'
84+ self .results_dir = 'results'
85+ self .result_file_format = '{model_name}_len_{context_length}_depth_{depth_percent}'
8086 self .testing_results = []
8187
8288 if context_lengths is None :
8389 if context_lengths_min is None or context_lengths_max is None or context_lengths_num_intervals is None :
8490 raise ValueError ("Either context_lengths_min, context_lengths_max, context_lengths_intervals need to be filled out OR the context_lengths_list needs to be supplied." )
8591 else :
86- self .context_lengths = np .round (np .linspace (context_lengths_min , context_lengths_max , num = context_lengths_num_intervals , endpoint = True )).astype (int )
92+ self .context_lengths = self .get_intervals (context_lengths_min ,
93+ context_lengths_max ,
94+ context_lengths_num_intervals ,
95+ "linear" )
8796 else :
8897 self .context_lengths = context_lengths
8998
@@ -93,13 +102,13 @@ def __init__(self,
93102 if document_depth_percents is None :
94103 if document_depth_percent_min is None or document_depth_percent_max is None or document_depth_percent_intervals is None :
95104 raise ValueError ("Either document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals need to be filled out OR the document_depth_percents needs to be supplied." )
96-
97- if document_depth_percent_interval_type == 'linear' :
98- self .document_depth_percents = np .round (np .linspace (document_depth_percent_min , document_depth_percent_max , num = document_depth_percent_intervals , endpoint = True )).astype (int )
99- elif document_depth_percent_interval_type == 'sigmoid' :
100- self .document_depth_percents = [self .logistic (x ) for x in np .linspace (document_depth_percent_min , document_depth_percent_max , document_depth_percent_intervals )]
101- else :
105+ if document_depth_percent_interval_type not in ['linear' , 'sigmoid' ]:
102106 raise ValueError ("document_depth_percent_interval_type must be either 'sigmoid' or 'linear' if document_depth_percents is None." )
107+
108+ self .document_depth_percents = self .get_intervals (document_depth_percent_min ,
109+ document_depth_percent_max ,
110+ document_depth_percent_intervals ,
111+ document_depth_percent_interval_type )
103112 else :
104113 self .document_depth_percents = document_depth_percents
105114
@@ -108,6 +117,17 @@ def __init__(self,
108117
109118 self .evaluation_model = evaluator
110119
120+ def get_intervals (self , min_depth , max_depth , num_intervals , interval_type ):
121+ linear_spacing = np .linspace (min_depth , max_depth , num = num_intervals , endpoint = True )
122+
123+ match interval_type :
124+ case 'linear' :
125+ return np .round (linear_spacing ).astype (int )
126+ case 'sigmoid' :
127+ return [self .logistic (x ) for x in linear_spacing ]
128+ case _:
129+ return []
130+
111131 def logistic (self , x , L = 100 , x0 = 50 , k = .1 ):
112132 if x in [0 , 100 ]:
113133 return x
@@ -122,24 +142,20 @@ async def bound_evaluate_and_log(self, sem, *args):
122142 await self .evaluate_and_log (* args )
123143
124144 async def run_test (self ):
125- sem = Semaphore (self .num_concurrent_requests )
126-
127- # Run through each iteration of context_lengths and depths
128- tasks = []
129- for context_length in self .context_lengths :
130- for depth_percent in self .document_depth_percents :
131- task = self .bound_evaluate_and_log (sem , context_length , depth_percent )
132- tasks .append (task )
145+ async with asyncio .TaskGroup () as tg :
146+ sem = Semaphore (self .num_concurrent_requests )
133147
134- # Wait for all tasks to complete
135- await asyncio .gather (* tasks )
148+ # Run through each iteration of context_lengths and depths
149+ for context_length in self .context_lengths :
150+ for depth_percent in self .document_depth_percents :
151+ task = self .bound_evaluate_and_log (sem , context_length , depth_percent )
152+ tg .create_task (task )
136153
137154 async def evaluate_and_log (self , context_length , depth_percent ):
138155 # Checks to see if you've already checked a length/percent/version.
139156 # This helps if the program stop running and you want to restart later
140- if self .save_results :
141- if self .result_exists (context_length , depth_percent ):
142- return
157+ if self .save_results and self .result_exists (context_length , depth_percent ):
158+ return
143159
144160 # Go generate the required length context and place your needle statement in
145161 context = await self .generate_context (context_length , depth_percent )
@@ -159,47 +175,45 @@ async def evaluate_and_log(self, context_length, depth_percent):
159175 score = self .evaluation_model .evaluate_response (response )
160176
161177 results = {
162- # 'context' : context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
163- 'model' : self .model_name ,
164- 'context_length' : int (context_length ),
165- 'depth_percent' : float (depth_percent ),
166- 'version' : self .results_version ,
167- 'needle' : self .needle ,
168- 'model_response' : response ,
169- 'score' : score ,
170- 'test_duration_seconds' : test_elapsed_time ,
171- 'test_timestamp_utc' : datetime .now (timezone .utc ).strftime ('%Y-%m-%d %H:%M:%S%z' )
178+ # 'context': context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
179+ 'model' : self .model_name ,
180+ 'context_length' : int (context_length ),
181+ 'depth_percent' : float (depth_percent ),
182+ 'version' : self .results_version ,
183+ 'needle' : self .needle ,
184+ 'model_response' : response ,
185+ 'score' : score ,
186+ 'test_duration_seconds' : test_elapsed_time ,
187+ 'test_timestamp_utc' : datetime .now (timezone .utc ).strftime ('%Y-%m-%d %H:%M:%S%z' )
172188 }
173189
174190 self .testing_results .append (results )
175191
176192 if self .print_ongoing_status :
177- print (f"-- Test Summary -- " )
178- print (f"Duration: { test_elapsed_time :.1f} seconds" )
179- print (f"Context: { context_length } tokens" )
180- print (f"Depth: { depth_percent } %" )
181- print (f"Score: { score } " )
182- print (f"Response: { response } \n " )
193+ self .print_status (test_elapsed_time , context_length , depth_percent , score , response )
183194
184- context_file_location = f'{ self .model_name .replace ("." , "_" )} _len_{ context_length } _depth_{ int (depth_percent * 100 )} '
195+ parsed_model_name = self .model_name .replace ("." , "_" )
196+ context_file_location = self .result_file_format .format (model_name = parsed_model_name ,
197+ context_length = context_length ,
198+ depth_percent = int (depth_percent ))
185199
186200 if self .save_contexts :
187201 results ['file_name' ] = context_file_location
188202
189203 # Save the context to file for retesting
190- if not os .path .exists ('contexts' ):
191- os .makedirs ('contexts' )
204+ if not os .path .exists (self . context_dir ):
205+ os .makedirs (self . context_dir )
192206
193- with open (f'contexts /{ context_file_location } _context.txt' , 'w' ) as f :
207+ with open (f'{ self . context_dir } /{ context_file_location } _context.txt' , 'w' ) as f :
194208 f .write (context )
195209
196210 if self .save_results :
197211 # Save the context to file for retesting
198- if not os .path .exists ('results' ):
199- os .makedirs ('results' )
212+ if not os .path .exists (self . results_dir ):
213+ os .makedirs (self . results_dir )
200214
201215 # Save the result to file for retesting
202- with open (f'results /{ context_file_location } _results.json' , 'w' ) as f :
216+ with open (f'{ self . results_dir } /{ context_file_location } _results.json' , 'w' ) as f :
203217 json .dump (results , f )
204218
205219 if self .seconds_to_sleep_between_completions :
def result_exists(self, context_length, depth_percent):
    """
    Check whether a result for this (context_length, depth_percent) pair and
    results_version has already been saved, so a restarted run can skip it.

    :param context_length: Context size in tokens for the test being checked.
    :param depth_percent: Needle depth percentage for the test being checked.
    :return: True if a matching result file with the current version exists.
    """
    if not os.path.exists(self.results_dir):
        return False

    # BUGFIX: mirror the exact naming used when saving in evaluate_and_log —
    # dots in the model name are replaced with underscores, the depth is
    # truncated to int, and the saved file carries a '_results' suffix.
    # Without this, previously saved results were never found.
    parsed_model_name = self.model_name.replace(".", "_")
    filename = self.result_file_format.format(model_name=parsed_model_name,
                                              context_length=context_length,
                                              depth_percent=int(depth_percent))
    file_path = os.path.join(self.results_dir, f'{filename}_results.json')
    if not os.path.exists(file_path):
        return False

    with open(file_path, 'r') as f:
        result = json.load(f)

    # Older result files may predate the 'version' key; treat them as version 1.
    return result.get('version', 1) == self.results_version
229244 async def generate_context (self , context_length , depth_percent ):
@@ -265,9 +280,9 @@ def insert_needle(self, context, depth_percent, context_length):
265280 period_tokens = self .model_to_test .encode_text_to_tokens ('.' )
266281
267282 # Then we iteration backwards until we find the first period
268- while tokens_new_context and tokens_new_context [- 1 ] not in period_tokens :
283+ while ( insertion_point > 0 ) and ( tokens_new_context [insertion_point - 1 ] != period_tokens ) :
269284 insertion_point -= 1
270- tokens_new_context = tokens_context [:insertion_point ]
285+ tokens_new_context = tokens_context [:insertion_point ]
271286
272287 # Once we get there, then add in your needle, and stick the rest of your context in on the other end.
273288 # Now we have a needle in a haystack
@@ -282,12 +297,16 @@ def get_context_length_in_tokens(self, context):
282297
def read_context_files(self):
    """
    Concatenate the haystack .txt files (cycling through them repeatedly)
    until the accumulated text reaches at least the largest requested
    context length in tokens.

    :return: The concatenated haystack text.
    :raises ValueError: if the haystack directory contains no .txt files.
    """
    context = ""
    current_context_length = 0
    max_context_length = max(self.context_lengths)

    # Glob once: the file set does not change between passes, and globbing
    # inside the while loop re-scanned the directory on every pass.
    files = glob.glob(f"{self.haystack_dir}/*.txt")
    if not files:
        # BUGFIX: with no files the loop body could never grow the context,
        # so the while loop would spin forever. Fail fast instead.
        raise ValueError(f"No .txt files found in haystack directory: {self.haystack_dir}")

    while current_context_length < max_context_length:
        for file in files:
            with open(file, 'r') as f:
                file_content = f.read()

            context += file_content
            # Count tokens per file as it is read instead of re-tokenizing
            # the whole accumulated context on every pass.
            current_context_length += self.get_context_length_in_tokens(file_content)
    return context
292311
293312 def encode_and_trim (self , context , context_length ):
@@ -308,6 +327,14 @@ def print_start_test_summary(self):
308327 print (f"- Needle: { self .needle .strip ()} " )
309328 print ("\n \n " )
310329
def print_status(self, elapsed_time, context_length, depth_percent, score, response):
    """
    Print a short summary of a single completed test to stdout.

    :param elapsed_time: Wall-clock duration of the test in seconds.
    :param context_length: Context size in tokens used for this test.
    :param depth_percent: Needle depth as a percentage of the context.
    :param score: Score the evaluator assigned to the model's response.
    :param response: The raw model response that was evaluated.
    """
    # Plain string: the original used an f-string with no placeholders (F541).
    print("-- Test Summary -- ")
    print(f"Duration: {elapsed_time:.1f} seconds")
    print(f"Context: {context_length} tokens")
    print(f"Depth: {depth_percent}%")
    print(f"Score: {score}")
    print(f"Response: {response}\n")
311338 def start_test (self ):
312339 if self .print_ongoing_status :
313340 self .print_start_test_summary ()
0 commit comments