
Commit afaf8df

refactor: revise locomo_eval to make it support llm other than gpt-4o-mini
1 parent: bc7236f

File tree

1 file changed: +42, -15 lines


evaluation/scripts/locomo/locomo_eval.py

Lines changed: 42 additions & 15 deletions
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import re
 import time
 
 import nltk
@@ -47,6 +48,29 @@ class LLMGrade(BaseModel):
     llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.")
 
 
+def extract_label_json(text: str) -> str | None:
+    """
+    Extracts a JSON object of the form {"label": "VALUE"} from a given text string.
+    This function is designed to handle cases where the LLM response contains
+    natural language alongside a final JSON snippet, ensuring robust parsing.
+
+    Supports both single and double quotes around the label value.
+    Ignores surrounding whitespace and formatting.
+
+    Returns:
+        The full matching JSON string (e.g., '{"label": "CORRECT"}') if found.
+        None if no valid label JSON is found.
+    """
+    # Regex pattern to match: { "label": "value" } with optional whitespace
+    # Matches both single and double quotes, allows spaces around keys and values
+    pattern = r'\{\s*"label"\s*:\s*["\']([^"\']*)["\']\s*\}'
+    match = re.search(pattern, text)
+    if match:
+        # Return the complete matched JSON string for safe json.loads()
+        return match.group(0)
+    return None
+
+
 async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool:
     system_prompt = """
     You are an expert grader that determines if answers to questions match a gold standard answer
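The new helper exists because models other than gpt-4o-mini often wrap the verdict in prose, which breaks a direct json.loads on the raw reply. A minimal sketch of the behavior (the function body is copied from the hunk above so the snippet runs standalone; the reply strings are made-up examples):

import json
import re

# Copied from the commit so the snippet is self-contained.
def extract_label_json(text: str) -> str | None:
    pattern = r'\{\s*"label"\s*:\s*["\']([^"\']*)["\']\s*\}'
    match = re.search(pattern, text)
    return match.group(0) if match else None

# Made-up judge replies: bare JSON versus a verdict wrapped in explanation.
bare = '{"label": "CORRECT"}'
chatty = 'The response names the same city as the gold answer. {"label": "CORRECT"}'

for reply in (bare, chatty):
    snippet = extract_label_json(reply)   # '{"label": "CORRECT"}' in both cases
    print(json.loads(snippet)["label"])   # CORRECT

One caveat: because the regex also accepts single quotes, a reply like {'label': 'CORRECT'} matches, but group(0) returns it verbatim and json.loads would still raise; the try/except added in the next hunk is what catches that case.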
@@ -77,20 +101,23 @@ async def locomo_grader(llm_client, question: str, gold_answer: str, response: s
 
     Just return the label CORRECT or WRONG in a json format with the key as "label".
     """
-
-    response = await llm_client.chat.completions.create(
-        model="gpt-4o-mini",
-        messages=[
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": accuracy_prompt},
-        ],
-        temperature=0,
-    )
-    message_content = response.choices[0].message.content
-    label = json.loads(message_content)["label"]
-    parsed = LLMGrade(llm_judgment=label, llm_reasoning="")
-
-    return parsed.llm_judgment.strip().lower() == "correct"
+    try:
+        response = await llm_client.chat.completions.create(
+            model=os.getenv("EVAL_MODEL", "gpt-4o-mini"),
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": accuracy_prompt},
+            ],
+            temperature=0,
+        )
+        message_content = response.choices[0].message.content
+        message_content = extract_label_json(text=message_content)
+        label = json.loads(message_content)["label"]
+        parsed = LLMGrade(llm_judgment=label, llm_reasoning="")
+        return parsed.llm_judgment.strip().lower() == "correct"
+    except Exception as e:
+        print(f"======== {e}, {response} ===========")
+        exit()
 
 
 def calculate_rouge_scores(gold_answer, response):
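Reading the model name from EVAL_MODEL means the grader can target any OpenAI-compatible endpoint with no further code changes. A minimal sketch of such a setup; the model name, URL, and key below are placeholders, not part of the commit:

import os
from openai import AsyncOpenAI

os.environ["EVAL_MODEL"] = "qwen2.5-7b-instruct"  # placeholder model name

# Placeholder endpoint: any OpenAI-compatible server (vLLM, Ollama, etc.) works.
llm_client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",  # local servers typically ignore the key
)

# locomo_grader(llm_client, question, gold_answer, response) now resolves the
# model via os.getenv("EVAL_MODEL", "gpt-4o-mini") instead of a hard-coded name.

Note that the new except branch prints the error and calls exit(), so a single unparseable reply halts the whole evaluation run rather than being skipped.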
@@ -284,7 +311,7 @@ async def main(frame, version="default", options=None, num_runs=1, max_workers=4
     with open(response_path) as file:
         locomo_responses = json.load(file)
 
-    num_users = 10
+    num_users = 2
     all_grades = {}
 
     total_responses_count = sum(