Skip to content

Commit 4a4abca

Browse files
committed
Fix bugs in the rule-based baselines and change the temporal data sorting strategy
1 parent 7119091 commit 4a4abca

File tree

6 files changed

+417
-314
lines changed

6 files changed

+417
-314
lines changed

evaluation/scripts/temporal_locomo/locomo_processor.py

Lines changed: 56 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import json
22
import sys
3-
import traceback
43

54
from collections import defaultdict
65
from concurrent.futures import ThreadPoolExecutor, as_completed
7-
from datetime import datetime
86
from pathlib import Path
97
from time import time
108

@@ -20,10 +18,8 @@
2018
SEARCH_PROMPT_MEMOS,
2119
SEARCH_PROMPT_ZEP,
2220
)
23-
from modules.schemas import RecordingCase
21+
from modules.schemas import ContextUpdateMethod, RecordingCase
2422
from modules.utils import save_evaluation_cases
25-
from openai import OpenAI
26-
from tqdm import tqdm
2723

2824
from memos.log import get_logger
2925

@@ -57,68 +53,24 @@ def __init__(self, args):
5753

5854
self.processed_data_dir = self.result_dir / "processed_data"
5955

60-
# -------------------------------
61-
# Refactor helpers for process_user
62-
# -------------------------------
63-
64-
def _initialize_conv_stats(self):
65-
"""Create a fresh statistics dictionary for a conversation."""
66-
return {
67-
"total_queries": 0,
68-
"can_answer_count": 0,
69-
"cannot_answer_count": 0,
70-
"answer_hit_rate": 0.0,
71-
"response_failure": 0,
72-
"response_count": 0,
73-
}
74-
75-
def _build_day_groups(self, temporal_conv):
76-
"""Build mapping day_id -> qa_pairs from a temporal conversation dict."""
77-
day_groups = {}
78-
for day_id, day_data in temporal_conv.get("days", {}).items():
79-
day_groups[day_id] = day_data.get("qa_pairs", [])
80-
return day_groups
81-
82-
def _build_metadata(self, speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id):
83-
"""Assemble metadata for downstream calls."""
84-
return {
85-
"speaker_a": speaker_a,
86-
"speaker_b": speaker_b,
87-
"speaker_a_user_id": speaker_a_user_id,
88-
"speaker_b_user_id": speaker_b_user_id,
89-
"conv_id": conv_id,
90-
}
91-
92-
def _get_clients(self, frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k):
93-
"""Return (client, reversed_client) according to the target frame."""
94-
reversed_client = None
95-
if frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
96-
client = self.get_client_from_storage(frame, speaker_a_user_id, version, top_k=top_k)
97-
reversed_client = self.get_client_from_storage(
98-
frame, speaker_b_user_id, version, top_k=top_k
99-
)
56+
def update_context(self, conv_id, method, **kwargs):
    """Update the cached pre-context for a conversation.

    DIRECT overwrites the cache with ``cur_context``; TEMPLATE appends one
    query/answer turn via ``_update_context_template``. Raises ValueError
    when a required kwarg is missing or the method is unknown.
    """
    if method == ContextUpdateMethod.DIRECT:
        if "cur_context" not in kwargs:
            raise ValueError("cur_context is required for DIRECT update method")
        self.pre_context_cache[conv_id] = kwargs["cur_context"]
        return
    if method == ContextUpdateMethod.TEMPLATE:
        if "query" not in kwargs or "answer" not in kwargs:
            raise ValueError("query and answer are required for TEMPLATE update method")
        self._update_context_template(conv_id, kwargs["query"], kwargs["answer"])
        return
    raise ValueError(f"Unsupported update method: {method}")
101-
client = self.get_client_from_storage(frame, conv_id, version)
102-
return client, reversed_client
103-
104-
def _save_conv_stats(self, conv_id, frame, version, conv_stats, conv_stats_path):
105-
"""Persist per-conversation stats to disk."""
106-
conv_stats_data = {
107-
"conversation_id": conv_id,
108-
"frame": frame,
109-
"version": version,
110-
"statistics": conv_stats,
111-
"timestamp": str(datetime.now()),
112-
}
113-
with open(conv_stats_path, "w") as fw:
114-
json.dump(conv_stats_data, fw, indent=2, ensure_ascii=False)
115-
print(f"Saved conversation stats for {conv_id} to {conv_stats_path}")
67+
raise ValueError(f"Unsupported update method: {method}")
11668

117-
def _write_user_search_results(self, user_search_path, search_results, conv_id):
118-
"""Write per-user search results to a temporary JSON file."""
119-
with open(user_search_path, "w") as fw:
120-
json.dump(dict(search_results), fw, indent=2)
121-
print(f"Save search results {conv_id}")
69+
def _update_context_template(self, conv_id, query, answer):
70+
new_context = f"User: {query}\nAssistant: {answer}\n\n"
71+
if self.pre_context_cache[conv_id] is None:
72+
self.pre_context_cache[conv_id] = ""
73+
self.pre_context_cache[conv_id] += new_context
12274

12375
def _process_single_qa(
12476
self,
@@ -136,24 +88,35 @@ def _process_single_qa(
13688
conv_stats,
13789
):
13890
query = qa.get("question")
91+
gold_answer = qa.get("answer")
13992
qa_category = qa.get("category")
14093
if qa_category == 5:
14194
return None
14295

14396
# Search
144-
context, search_duration_ms = self.search_query(
97+
cur_context, search_duration_ms = self.search_query(
14598
client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
14699
)
147-
if not context:
100+
if not cur_context:
148101
logger.warning(f"No context found for query: {query[:100]}")
149-
context = ""
102+
cur_context = ""
150103

151104
# Context answerability analysis (for memos_scheduler only)
152-
gold_answer = qa.get("answer")
153105
if self.pre_context_cache[conv_id] is None:
154106
# Update pre-context cache with current context
155-
with self.stats_lock:
156-
self.pre_context_cache[conv_id] = context
107+
if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
108+
self.update_context(
109+
conv_id=conv_id,
110+
method=self.context_update_method,
111+
cur_context=cur_context,
112+
)
113+
else:
114+
self.update_context(
115+
conv_id=conv_id,
116+
method=self.context_update_method,
117+
query=query,
118+
answer=gold_answer,
119+
)
157120
return None
158121

159122
can_answer = False
@@ -181,15 +144,9 @@ def _process_single_qa(
181144
)
182145
self.save_stats()
183146

184-
# Update pre-context cache with current context
185-
with self.stats_lock:
186-
self.pre_context_cache[conv_id] = context
187-
188-
self.print_eval_info()
189-
190147
# Generate answer
191148
answer_start = time()
192-
answer = self.locomo_response(frame, oai_client, context, query)
149+
answer = self.locomo_response(frame, oai_client, self.pre_context_cache[conv_id], query)
193150
response_duration_ms = (time() - answer_start) * 1000
194151

195152
# Record case for memos_scheduler
@@ -199,7 +156,7 @@ def _process_single_qa(
199156
conv_id=conv_id,
200157
query=query,
201158
answer=answer,
202-
context=context,
159+
context=cur_context,
203160
pre_context=self.pre_context_cache[conv_id],
204161
can_answer=can_answer,
205162
can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
@@ -248,146 +205,37 @@ def _process_single_qa(
248205

249206
logger.info(f"Processed question: {query[:100]}")
250207
logger.info(f"Answer: {answer[:100]}")
208+
209+
# Update pre-context cache with current context
210+
with self.stats_lock:
211+
if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
212+
self.update_context(
213+
conv_id=conv_id,
214+
method=self.context_update_method,
215+
cur_context=cur_context,
216+
)
217+
else:
218+
self.update_context(
219+
conv_id=conv_id,
220+
method=self.context_update_method,
221+
query=query,
222+
answer=gold_answer,
223+
)
224+
225+
self.print_eval_info()
226+
251227
return {
252228
"question": query,
253229
"answer": answer,
254230
"category": qa_category,
255231
"golden_answer": gold_answer,
256-
"search_context": context,
232+
"search_context": cur_context,
257233
"response_duration_ms": response_duration_ms,
258234
"search_duration_ms": search_duration_ms,
259235
"can_answer_duration_ms": can_answer_duration_ms,
260236
"can_answer": can_answer if frame == "memos_scheduler" else None,
261237
}
262238

263-
def process_user(self, conv_id, locomo_df, frame, version, top_k=20):
264-
user_search_path = self.result_dir / f"tmp/{frame}_locomo_search_results_{conv_id}.json"
265-
user_search_path.parent.mkdir(exist_ok=True, parents=True)
266-
search_results = defaultdict(list)
267-
response_results = defaultdict(list)
268-
conv_stats_path = self.stats_dir / f"{frame}_{version}_conv_{conv_id}_stats.json"
269-
270-
conversation = locomo_df["conversation"].iloc[conv_id]
271-
speaker_a = conversation.get("speaker_a", "speaker_a")
272-
speaker_b = conversation.get("speaker_b", "speaker_b")
273-
274-
# Use temporal_locomo data if available, otherwise fall back to original locomo data
275-
temporal_conv = self.temporal_locomo_data[conv_id]
276-
conv_id = temporal_conv["conversation_id"]
277-
speaker_a_user_id = f"{conv_id}_speaker_a"
278-
speaker_b_user_id = f"{conv_id}_speaker_b"
279-
280-
# Process temporal data by days
281-
day_groups = {}
282-
for day_id, day_data in temporal_conv["days"].items():
283-
day_groups[day_id] = day_data["qa_pairs"]
284-
285-
# Initialize conversation-level statistics
286-
conv_stats = self._initialize_conv_stats()
287-
288-
metadata = self._build_metadata(
289-
speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id
290-
)
291-
292-
client, reversed_client = self._get_clients(
293-
frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k
294-
)
295-
296-
oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url)
297-
298-
with self.stats_lock:
299-
self.pre_context_cache[conv_id] = None
300-
301-
def process_qa(qa):
302-
return self._process_single_qa(
303-
qa,
304-
client=client,
305-
reversed_client=reversed_client,
306-
metadata=metadata,
307-
frame=frame,
308-
version=version,
309-
conv_id=conv_id,
310-
conv_stats_path=conv_stats_path,
311-
oai_client=oai_client,
312-
top_k=top_k,
313-
conv_stats=conv_stats,
314-
)
315-
316-
# ===================================
317-
conv_stats["theoretical_total_queries"] = 0
318-
for day, qa_list in day_groups.items():
319-
conv_stats["theoretical_total_queries"] += len(qa_list) - 1
320-
conv_stats["processing_failure_count"] = 0
321-
print(f"Processing user {conv_id} day {day}")
322-
for qa in tqdm(qa_list, desc=f"Processing user {conv_id} day {day}"):
323-
try:
324-
result = process_qa(qa)
325-
except Exception as e:
326-
logger.error(f"Error: {e}. traceback: {traceback.format_exc()}")
327-
conv_stats["processing_failure_count"] += 1
328-
continue
329-
if result:
330-
context_preview = (
331-
result["search_context"][:20] + "..."
332-
if result["search_context"]
333-
else "No context"
334-
)
335-
if "can_answer" in result:
336-
logger.info("Print can_answer case")
337-
logger.info(
338-
{
339-
"question": result["question"][:100],
340-
"pre context can answer": result["can_answer"],
341-
"answer": result["answer"][:100],
342-
"golden_answer": result["golden_answer"],
343-
"search_context": context_preview[:100],
344-
"search_duration_ms": result["search_duration_ms"],
345-
}
346-
)
347-
348-
search_results[conv_id].append(
349-
{
350-
"question": result["question"],
351-
"context": result["search_context"],
352-
"search_duration_ms": result["search_duration_ms"],
353-
}
354-
)
355-
response_results[conv_id].append(result)
356-
357-
logger.warning(
358-
f"Finished processing user {conv_id} day {day}, data_length: {len(qa_list)}"
359-
)
360-
361-
# recording separate search results
362-
with open(user_search_path, "w") as fw:
363-
json.dump(dict(search_results), fw, indent=2)
364-
print(f"Save search results {conv_id}")
365-
366-
# Dump stats after processing each user
367-
self.save_stats()
368-
369-
return search_results, response_results
370-
371-
def process_user_wrapper(self, args):
372-
"""
373-
Wraps the process_user function to support parallel execution and error handling.
374-
375-
Args:
376-
args: Tuple containing parameters for process_user
377-
378-
Returns:
379-
tuple: Contains user results or error information
380-
"""
381-
idx, locomo_df, frame, version, top_k = args
382-
try:
383-
print(f"Processing user {idx}...")
384-
user_search_results, user_response_results = self.process_user(
385-
idx, locomo_df, frame, version, top_k
386-
)
387-
return (user_search_results, user_response_results, None)
388-
except Exception as e:
389-
return (None, None, (idx, e, traceback.format_exc()))
390-
391239
def run_locomo_processing(self, num_users=10):
392240
load_dotenv()
393241

0 commit comments

Comments
 (0)