|
3 | 3 | import json |
4 | 4 | import logging |
5 | 5 | import os |
| 6 | +import re |
6 | 7 | import time |
7 | 8 |
|
8 | 9 | import nltk |
@@ -47,6 +48,29 @@ class LLMGrade(BaseModel): |
47 | 48 | llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.") |
48 | 49 |
|
49 | 50 |
|
| 51 | +def extract_label_json(text: str) -> str | None: |
| 52 | + """ |
| 53 | + Extracts a JSON object of the form {"label": "VALUE"} from a given text string. |
| 54 | + This function is designed to handle cases where the LLM response contains |
| 55 | + natural language alongside a final JSON snippet, ensuring robust parsing. |
| 56 | +
|
| 57 | + Supports both single and double quotes around the label value. |
| 58 | + Ignores surrounding whitespace and formatting. |
| 59 | +
|
| 60 | + Returns: |
| 61 | + The full matching JSON string (e.g., '{"label": "CORRECT"}') if found. |
| 62 | + None if no valid label JSON is found. |
| 63 | + """ |
| 64 | + # Regex pattern to match: { "label": "value" } with optional whitespace |
| 65 | + # Matches both single and double quotes, allows spaces around keys and values |
| 66 | + pattern = r'\{\s*"label"\s*:\s*["\']([^"\']*)["\']\s*\}' |
| 67 | + match = re.search(pattern, text) |
| 68 | + if match: |
| 69 | + # Return the complete matched JSON string for safe json.loads() |
| 70 | + return match.group(0) |
| 71 | + return None |
| 72 | + |
| 73 | + |
50 | 74 | async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool: |
51 | 75 | system_prompt = """ |
52 | 76 | You are an expert grader that determines if answers to questions match a gold standard answer |
@@ -77,20 +101,23 @@ async def locomo_grader(llm_client, question: str, gold_answer: str, response: s |
77 | 101 |
|
78 | 102 | Just return the label CORRECT or WRONG in a json format with the key as "label". |
79 | 103 | """ |
80 | | - |
81 | | - response = await llm_client.chat.completions.create( |
82 | | - model="gpt-4o-mini", |
83 | | - messages=[ |
84 | | - {"role": "system", "content": system_prompt}, |
85 | | - {"role": "user", "content": accuracy_prompt}, |
86 | | - ], |
87 | | - temperature=0, |
88 | | - ) |
89 | | - message_content = response.choices[0].message.content |
90 | | - label = json.loads(message_content)["label"] |
91 | | - parsed = LLMGrade(llm_judgment=label, llm_reasoning="") |
92 | | - |
93 | | - return parsed.llm_judgment.strip().lower() == "correct" |
| 104 | + try: |
| 105 | + response = await llm_client.chat.completions.create( |
| 106 | + model=os.getenv("EVAL_MODEL", "gpt-4o-mini"), |
| 107 | + messages=[ |
| 108 | + {"role": "system", "content": system_prompt}, |
| 109 | + {"role": "user", "content": accuracy_prompt}, |
| 110 | + ], |
| 111 | + temperature=0, |
| 112 | + ) |
| 113 | + message_content = response.choices[0].message.content |
| 114 | + message_content = extract_label_json(text=message_content) |
| 115 | + label = json.loads(message_content)["label"] |
| 116 | + parsed = LLMGrade(llm_judgment=label, llm_reasoning="") |
| 117 | + return parsed.llm_judgment.strip().lower() == "correct" |
| 118 | + except Exception as e: |
| 119 | + print(f"======== {e}, {response} ===========") |
| 120 | + exit() |
94 | 121 |
|
95 | 122 |
|
96 | 123 | def calculate_rouge_scores(gold_answer, response): |
@@ -284,7 +311,7 @@ async def main(frame, version="default", options=None, num_runs=1, max_workers=4 |
284 | 311 | with open(response_path) as file: |
285 | 312 | locomo_responses = json.load(file) |
286 | 313 |
|
287 | | - num_users = 10 |
| 314 | + num_users = 2 |
288 | 315 | all_grades = {} |
289 | 316 |
|
290 | 317 | total_responses_count = sum( |
|
0 commit comments