import argparse
import asyncio
import json
import logging
import os
import time

import nltk
import numpy as np
import transformers

from bert_score import score as bert_score
from dotenv import load_dotenv
from locomo_processor import BASE_DIR
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from nltk.translate.meteor_score import meteor_score
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from rouge_score import rouge_scorer
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


logging.basicConfig(level=logging.CRITICAL)
transformers.logging.set_verbosity_error()

# Download necessary NLTK resources
try:
    nltk.download("wordnet", quiet=True)
    nltk.download("punkt", quiet=True)
    print("NLTK resources downloaded successfully.")
except Exception as e:
    print(f"Warning: Failed to download NLTK resources: {e}")


try:
    sentence_model_name = "Qwen/Qwen3-Embedding-0.6B"
    sentence_model = SentenceTransformer(sentence_model_name)
    print(f"SentenceTransformer model : {sentence_model_name} loaded successfully.")
except Exception as e:
    print(f"Failed to load SentenceTransformer model: {e}")
    sentence_model = None


class LLMGrade(BaseModel):
    llm_judgment: str = Field(description="CORRECT or WRONG")
    llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.")

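# LLM-as-a-judge: ask gpt-4o-mini whether the generated answer matches the gold
# answer for a LoCoMo question; returns True for CORRECT and False for WRONG.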
async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool:
    system_prompt = """
    You are an expert grader that determines if answers to questions match a gold standard answer
    """

    accuracy_prompt = f"""
    Your task is to label an answer to a question as 'CORRECT' or 'WRONG'. You will be given the following data:
    (1) a question (posed by one user to another user),
    (2) a 'gold' (ground truth) answer,
    (3) a generated answer
    which you will score as CORRECT/WRONG.

    The point of the question is to ask about something one user should know about the other user based on their prior conversations.
    The gold answer will usually be a concise and short answer that includes the referenced topic, for example:
    Question: Do you remember what I got the last time I went to Hawaii?
    Gold answer: A shell necklace
    The generated answer might be much longer, but you should be generous with your grading - as long as it touches on the same topic as the gold answer, it should be counted as CORRECT.

    For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.

    Now it's time for the real question:
    Question: {question}
    Gold answer: {gold_answer}
    Generated answer: {response}

    First, provide a short (one sentence) explanation of your reasoning, then finish with CORRECT or WRONG.
    Do NOT include both CORRECT and WRONG in your response, or it will break the evaluation script.

    Just return the label CORRECT or WRONG in a json format with the key as "label".
    """

    # Request a JSON object so the json.loads call below reliably gets parseable output.
    response = await llm_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": accuracy_prompt},
        ],
        temperature=0,
        response_format={"type": "json_object"},
    )
    message_content = response.choices[0].message.content
    label = json.loads(message_content)["label"]
    parsed = LLMGrade(llm_judgment=label, llm_reasoning="")

    return parsed.llm_judgment.strip().lower() == "correct"


def calculate_rouge_scores(gold_answer, response):
    metrics = {"rouge1_f": 0.0, "rouge2_f": 0.0, "rougeL_f": 0.0}
    try:
        scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
        rouge_scores = scorer.score(gold_answer, response)
        metrics["rouge1_f"] = rouge_scores["rouge1"].fmeasure
        metrics["rouge2_f"] = rouge_scores["rouge2"].fmeasure
        metrics["rougeL_f"] = rouge_scores["rougeL"].fmeasure
    except Exception as e:
        print(f"Failed to calculate ROUGE scores: {e}")
    return metrics


def calculate_bleu_scores(gold_tokens, response_tokens):
    metrics = {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}

    try:
        smoothing = SmoothingFunction().method1
        weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]

        for i, weight in enumerate(weights, 1):
            metrics[f"bleu{i}"] = sentence_bleu(
                [gold_tokens], response_tokens, weights=weight, smoothing_function=smoothing
            )
    except ZeroDivisionError:
        pass
    except Exception as e:
        print(f"Failed to calculate BLEU scores: {e}")

    return metrics


def calculate_meteor_score(gold_tokens, response_tokens):
    try:
        return meteor_score([gold_tokens], response_tokens)
    except Exception as e:
        print(f"Failed to calculate METEOR score: {e}")
        return 0.0


def calculate_semantic_similarity(gold_answer, response):
    global sentence_model

    try:
        if sentence_model is None:
            sentence_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")

        gold_embedding = sentence_model.encode([gold_answer], show_progress_bar=False)[0]
        response_embedding = sentence_model.encode([response], show_progress_bar=False)[0]
        return 1 - cosine(gold_embedding, response_embedding)
    except Exception as e:
        print(f"Failed to calculate semantic similarity: {e}")
        return 0.0


def calculate_f1_score(gold_tokens, response_tokens):
    try:
        gold_set = set(gold_tokens)
        response_set = set(response_tokens)

        if len(gold_set) == 0 or len(response_set) == 0:
            return 0.0

        precision = len(gold_set.intersection(response_set)) / len(response_set)
        recall = len(gold_set.intersection(response_set)) / len(gold_set)

        if precision + recall > 0:
            return 2 * precision * recall / (precision + recall)
        return 0.0
    except Exception as e:
        print(f"Failed to calculate F1 score: {e}")
        return 0.0

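# Lexical metrics (token F1, ROUGE, BLEU, METEOR) and semantic metrics (embedding
# cosine similarity, BERTScore F1) are grouped behind the "lexical"/"semantic" options.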
def calculate_nlp_metrics(gold_answer, response, context, options=None):
    if options is None:
        options = ["lexical", "semantic"]

    gold_answer = str(gold_answer) if gold_answer is not None else ""
    response = str(response) if response is not None else ""

    metrics = {"context_tokens": len(nltk.word_tokenize(context)) if context else 0}

    if "lexical" in options:
        gold_tokens = nltk.word_tokenize(gold_answer.lower())
        response_tokens = nltk.word_tokenize(response.lower())

        metrics["lexical"] = {}
        metrics["lexical"]["f1"] = calculate_f1_score(gold_tokens, response_tokens)
        metrics["lexical"].update(calculate_rouge_scores(gold_answer, response))
        metrics["lexical"].update(calculate_bleu_scores(gold_tokens, response_tokens))
        metrics["lexical"]["meteor"] = calculate_meteor_score(gold_tokens, response_tokens)

    if "semantic" in options:
        metrics["semantic"] = {}
        metrics["semantic"]["similarity"] = calculate_semantic_similarity(gold_answer, response)
        _, _, f1 = bert_score(
            [gold_answer], [response], lang="en", rescale_with_baseline=True, verbose=False
        )
        metrics["semantic"]["bert_f1"] = f1.item() if f1 is not None else 0.0

    return metrics


def convert_numpy_types(obj):
    if isinstance(obj, np.number):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_numpy_types(i) for i in obj]
    else:
        return obj

async def process_group_responses(group_id, group_responses, oai_client, options, num_runs: int):
    graded_responses = []

    # For each response, run the LLM judge num_runs times concurrently and attach NLP metrics
    for response in tqdm(group_responses, desc=f"Processing group {group_id}"):
        question = response.get("question")
        answer = response.get("answer")
        ground_truth = response.get("golden_answer")
        category = response.get("category")

        context = response.get("search_context", "")
        response_duration_ms = response.get("response_duration_ms", 0.0)
        search_duration_ms = response.get("search_duration_ms", 0.0)

        if ground_truth is None:
            continue

        grading_tasks = [
            locomo_grader(oai_client, question, ground_truth, answer) for _ in range(num_runs)
        ]
        judgments = await asyncio.gather(*grading_tasks)
        judgments_dict = {f"judgment_{i + 1}": j for i, j in enumerate(judgments)}

        nlp_metrics = calculate_nlp_metrics(ground_truth, answer, context, options)

        graded_response = {
            "question": question,
            "answer": answer,
            "golden_answer": ground_truth,
            "category": category,
            "llm_judgments": judgments_dict,
            "nlp_metrics": nlp_metrics,
            "response_duration_ms": response_duration_ms,
            "search_duration_ms": search_duration_ms,
            "total_duration_ms": response_duration_ms + search_duration_ms,
        }
        graded_responses.append(graded_response)

    return group_id, graded_responses

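# Wrapper that times a single user group and isolates failures, so one failing
# group returns an empty result instead of aborting the whole evaluation run.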
async def process_single_group(group_id, group_responses, oai_client, options, num_runs):
    try:
        start_time = time.time()
        result = await process_group_responses(
            group_id, group_responses, oai_client, options, num_runs
        )
        end_time = time.time()
        elapsed_time = round(end_time - start_time, 2)
        print(f"Group {group_id} processed in {elapsed_time} seconds")
        return result
    except Exception as e:
        print(f"Error processing group {group_id}: {e}")
        return group_id, []


async def main(frame, version="default", options=None, num_runs=1, max_workers=4):
    print(
        f"\n=== Starting LoCoMo evaluation for {frame} (version: {version}) with {num_runs} run(s) per question ==="
    )
    print(f"Using {max_workers} concurrent workers for processing groups")

    results_dir = f"{BASE_DIR}/results/locomo/{frame}-{version}"
    response_path = f"{results_dir}/{frame}_locomo_responses.json"
    judged_path = f"{results_dir}/{frame}_locomo_judged.json"

    os.makedirs(results_dir, exist_ok=True)

    load_dotenv()
    oai_client = AsyncOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")
    )

    with open(response_path) as file:
        locomo_responses = json.load(file)

    num_users = 10
    all_grades = {}

    total_responses_count = sum(
        len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users)
    )
    print(f"Found {total_responses_count} total responses across {num_users} users to evaluate")

    # Create tasks for processing each group
    tasks = []
    active_users = 0
    for group_idx in range(num_users):
        group_id = f"locomo_exp_user_{group_idx}"
        group_responses = locomo_responses.get(group_id, [])
        if not group_responses:
            print(f"No responses found for group {group_id}")
            continue

        active_users += 1
        tasks.append(process_single_group(group_id, group_responses, oai_client, options, num_runs))

    print(f"Starting evaluation of {active_users} user groups with responses")

    semaphore = asyncio.Semaphore(max_workers)

    async def limited_task(task):
        async with semaphore:
            return await task

    limited_tasks = [limited_task(task) for task in tasks]
    group_results = await asyncio.gather(*limited_tasks)

    for group_id, graded_responses in group_results:
        all_grades[group_id] = graded_responses

    print("\n=== Evaluation Complete: Calculating final scores ===")

    run_scores = []
    evaluated_count = 0
    if num_runs > 0:
        for i in range(1, num_runs + 1):
            judgment_key = f"judgment_{i}"
            current_run_correct_count = 0
            current_run_total_count = 0
            for group in all_grades.values():
                for response in group:
                    if judgment_key in response["llm_judgments"]:
                        if response["llm_judgments"][judgment_key]:
                            current_run_correct_count += 1
                        current_run_total_count += 1

            if current_run_total_count > 0:
                run_accuracy = current_run_correct_count / current_run_total_count
                run_scores.append(run_accuracy)

            evaluated_count = current_run_total_count

    if evaluated_count > 0:
        mean_of_scores = np.mean(run_scores)
        std_of_scores = np.std(run_scores)
        print(f"LLM-as-a-Judge Mean Score: {mean_of_scores:.4f}")
        print(f"LLM-as-a-Judge Standard Deviation: {std_of_scores:.4f}")
        print(f"(Calculated from {num_runs} separate runs over {evaluated_count} questions)")
        print(f"Individual run scores: {[round(s, 4) for s in run_scores]}")
    else:
        print("No responses were evaluated")
        print("LLM-as-a-Judge score: N/A (0/0)")

    all_grades = convert_numpy_types(all_grades)
    with open(judged_path, "w") as f:
        json.dump(all_grades, f, indent=2)
    print(f"Saved detailed evaluation results to {judged_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lib",
        type=str,
        default="memos_scheduler",
        choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"],
        help="Memory framework to evaluate (zep, memos, memos_scheduler, mem0, mem0_graph, langmem, or openai)",
    )
    parser.add_argument(
        "--version",
        type=str,
        default="v0.2.1",
        help="Version identifier for loading results (e.g., 1010)",
    )
    parser.add_argument(
        "--num_runs",
        type=int,
        default=3,
        help="Number of times to run the LLM grader for each question",
    )
    parser.add_argument("--options", nargs="+", default=["lexical", "semantic"])
    parser.add_argument(
        "--workers", type=int, default=10, help="Number of concurrent workers for processing groups"
    )
    args = parser.parse_args()

    asyncio.run(main(args.lib, args.version, args.options, args.num_runs, args.workers))
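# Example invocation (script filename assumed; adjust to the actual file name):
#   python locomo_eval.py --lib memos_scheduler --version v0.2.1 --num_runs 3 --workers 10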