99
1010from bert_score import score as bert_score
1111from dotenv import load_dotenv
12- from modules .locomo_eval_module import LocomoEvalModelModules
1312from nltk .translate .bleu_score import SmoothingFunction , sentence_bleu
1413from nltk .translate .meteor_score import meteor_score
1514from openai import AsyncOpenAI
1918from sentence_transformers import SentenceTransformer
2019from tqdm import tqdm
2120
21+ from evaluation .scripts .temporal_locomo .modules .locomo_eval_module import LocomoEvalModelModules
2222from memos .log import get_logger
2323
2424
@@ -281,33 +281,64 @@ def __init__(self, args):
281281 api_key = os .getenv ("OPENAI_API_KEY" ), base_url = os .getenv ("OPENAI_BASE_URL" )
282282 )
283283
284- async def run (self ):
285- print (
286- f"\n === Starting LoCoMo evaluation for { self .frame } (version: { self .version } ) with { self .num_runs } run(s) per question ==="
287- )
288- print (f"Using { self .max_workers } concurrent workers for processing groups" )
284+ def _load_response_data (self ):
285+ """
286+ Load response data from the response path file.
289287
288+ Returns:
289+ dict: The loaded response data
290+ """
290291 with open (self .response_path ) as file :
291- locomo_responses = json .load (file )
292+ return json .load (file )
292293
293- num_users = 10
294+ def _load_existing_evaluation_results (self ):
295+ """
296+ Attempt to load existing evaluation results from the judged path.
297+ If the file doesn't exist or there's an error loading it, return an empty dict.
298+
299+ Returns:
300+ dict: Existing evaluation results or empty dict if none available
301+ """
294302 all_grades = {}
303+ try :
304+ if os .path .exists (self .judged_path ):
305+ with open (self .judged_path ) as f :
306+ all_grades = json .load (f )
307+ print (f"Loaded existing evaluation results from { self .judged_path } " )
308+ except Exception as e :
309+ print (f"Error loading existing evaluation results: { e } " )
295310
296- total_responses_count = sum (
297- len (locomo_responses .get (f"locomo_exp_user_{ i } " , [])) for i in range (num_users )
298- )
299- print (f"Found { total_responses_count } total responses across { num_users } users to evaluate" )
311+ return all_grades
312+
313+ def _create_evaluation_tasks (self , locomo_responses , all_grades , num_users ):
314+ """
315+ Create evaluation tasks for groups that haven't been evaluated yet.
316+
317+ Args:
318+ locomo_responses (dict): The loaded response data
319+ all_grades (dict): Existing evaluation results
320+ num_users (int): Number of user groups to process
300321
301- # Create tasks for processing each group
322+ Returns:
323+ tuple: (tasks list, active users count)
324+ """
302325 tasks = []
303326 active_users = 0
327+
304328 for group_idx in range (num_users ):
305329 group_id = f"locomo_exp_user_{ group_idx } "
306330 group_responses = locomo_responses .get (group_id , [])
331+
307332 if not group_responses :
308333 print (f"No responses found for group { group_id } " )
309334 continue
310335
336+ # Skip groups that already have evaluation results
337+ if all_grades .get (group_id ):
338+ print (f"Skipping group { group_id } as it already has evaluation results" )
339+ active_users += 1
340+ continue
341+
311342 active_users += 1
312343 tasks .append (
313344 process_single_group (
@@ -319,29 +350,50 @@ async def run(self):
319350 )
320351 )
321352
322- print (f"Starting evaluation of { active_users } user groups with responses" )
353+ return tasks , active_users
354+
355+ async def _process_tasks (self , tasks ):
356+ """
357+ Process evaluation tasks with concurrency control.
358+
359+ Args:
360+ tasks (list): List of tasks to process
361+
362+ Returns:
363+ list: Results from processing all tasks
364+ """
365+ if not tasks :
366+ return []
323367
324368 semaphore = asyncio .Semaphore (self .max_workers )
325369
326370 async def limited_task (task ):
371+ """Helper function to limit concurrent task execution"""
327372 async with semaphore :
328373 return await task
329374
330375 limited_tasks = [limited_task (task ) for task in tasks ]
331- group_results = await asyncio .gather (* limited_tasks )
376+ return await asyncio .gather (* limited_tasks )
332377
333- for group_id , graded_responses in group_results :
334- all_grades [group_id ] = graded_responses
378+ def _calculate_scores (self , all_grades ):
379+ """
380+ Calculate evaluation scores based on all grades.
335381
336- print ("\n === Evaluation Complete: Calculating final scores ===" )
382+ Args:
383+ all_grades (dict): The complete evaluation results
337384
385+ Returns:
386+ tuple: (run_scores, evaluated_count)
387+ """
338388 run_scores = []
339389 evaluated_count = 0
390+
340391 if self .num_runs > 0 :
341392 for i in range (1 , self .num_runs + 1 ):
342393 judgment_key = f"judgment_{ i } "
343394 current_run_correct_count = 0
344395 current_run_total_count = 0
396+
345397 for group in all_grades .values ():
346398 for response in group :
347399 if judgment_key in response ["llm_judgments" ]:
@@ -355,6 +407,16 @@ async def limited_task(task):
355407
356408 evaluated_count = current_run_total_count
357409
410+ return run_scores , evaluated_count
411+
412+ def _report_scores (self , run_scores , evaluated_count ):
413+ """
414+ Report evaluation scores to the console.
415+
416+ Args:
417+ run_scores (list): List of accuracy scores for each run
418+ evaluated_count (int): Number of evaluated responses
419+ """
358420 if evaluated_count > 0 :
359421 mean_of_scores = np .mean (run_scores )
360422 std_of_scores = np .std (run_scores )
@@ -368,11 +430,63 @@ async def limited_task(task):
368430 print ("No responses were evaluated" )
369431 print ("LLM-as-a-Judge score: N/A (0/0)" )
370432
433+ def _save_results (self , all_grades ):
434+ """
435+ Save evaluation results to the judged path file.
436+
437+ Args:
438+ all_grades (dict): The complete evaluation results to save
439+ """
371440 all_grades = convert_numpy_types (all_grades )
372441 with open (self .judged_path , "w" ) as f :
373442 json .dump (all_grades , f , indent = 2 )
374443 print (f"Saved detailed evaluation results to { self .judged_path } " )
375444
445+ async def run (self ):
446+ """
447+ Main execution method for the LoCoMo evaluation process.
448+ This method orchestrates the entire evaluation workflow:
449+ 1. Loads existing evaluation results if available
450+ 2. Processes only groups that haven't been evaluated yet
451+ 3. Calculates and reports final evaluation scores
452+ """
453+ print (
454+ f"\n === Starting LoCoMo evaluation for { self .frame } (version: { self .version } ) with { self .num_runs } run(s) per question ==="
455+ )
456+ print (f"Using { self .max_workers } concurrent workers for processing groups" )
457+
458+ # Load response data and existing evaluation results
459+ locomo_responses = self ._load_response_data ()
460+ all_grades = self ._load_existing_evaluation_results ()
461+
462+ # Count total responses for reporting
463+ num_users = 10
464+ total_responses_count = sum (
465+ len (locomo_responses .get (f"locomo_exp_user_{ i } " , [])) for i in range (num_users )
466+ )
467+ print (f"Found { total_responses_count } total responses across { num_users } users to evaluate" )
468+
469+ # Create tasks only for groups that haven't been evaluated yet
470+ tasks , active_users = self ._create_evaluation_tasks (locomo_responses , all_grades , num_users )
471+ print (
472+ f"Starting evaluation of { len (tasks )} user groups with responses (out of { active_users } active users)"
473+ )
474+
475+ # Process tasks and update all_grades with results
476+ if tasks :
477+ group_results = await self ._process_tasks (tasks )
478+ for group_id , graded_responses in group_results :
479+ all_grades [group_id ] = graded_responses
480+
481+ print ("\n === Evaluation Complete: Calculating final scores ===" )
482+
483+ # Calculate and report scores
484+ run_scores , evaluated_count = self ._calculate_scores (all_grades )
485+ self ._report_scores (run_scores , evaluated_count )
486+
487+ # Save results
488+ self ._save_results (all_grades )
489+
376490
377491if __name__ == "__main__" :
378492 parser = argparse .ArgumentParser ()
0 commit comments