Skip to content

Commit 675eeca

Browse files
committed
add new feat of time eval for temporal locomo benchamrk, but this is not completed yet; revise the feat of multiple-thread task race for scheduler dispatcher, and add multi-thread task running functions to dispatcher.
1 parent a2715f5 commit 675eeca

File tree

12 files changed

+729
-192
lines changed

12 files changed

+729
-192
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
import sys
2+
import time
3+
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING
6+
7+
from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor
8+
from evaluation.scripts.temporal_locomo.modules.constants import (
9+
MEMOS_SCHEDULER_MODEL,
10+
)
11+
from evaluation.scripts.temporal_locomo.modules.prompts import (
12+
SEARCH_PROMPT_MEMOS,
13+
)
14+
from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase
15+
from memos.log import get_logger
16+
17+
18+
if TYPE_CHECKING:
19+
from memos.mem_os.main import MOS
20+
21+
FILE_PATH = Path(__file__).absolute()
22+
BASE_DIR = FILE_PATH.parent.parent.parent
23+
sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
24+
25+
logger = get_logger(__name__)
26+
27+
28+
class LocomoProcessorWithTimeEval(LocomoProcessor):
29+
def __init__(self, args):
30+
super().__init__(args=args)
31+
self.time_eval_mode = getattr(self.args, "time_eval_mode", False)
32+
assert self.args.frame == MEMOS_SCHEDULER_MODEL
33+
assert self.context_update_method == ContextUpdateMethod.PRE_CONTEXT
34+
if self.time_eval_mode:
35+
logger.warning(
36+
"time_eval_mode is activated. _process_single_qa is replaced by _process_single_qa_for_time_eval"
37+
)
38+
self._process_single_qa = self._process_single_qa_for_time_eval
39+
40+
def memos_scheduler_search(
41+
self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20
42+
):
43+
# MemOS full search process and skip the parts of scheduler
44+
start = time.time()
45+
client: MOS = client
46+
47+
if not self.scheduler_flag:
48+
# if not scheduler_flag, search to update working memory
49+
self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client)
50+
51+
# ========= MemOS Search =========
52+
# Search for speaker A
53+
search_a_results = client.search(
54+
query=query,
55+
user_id=conv_id + "_speaker_a",
56+
install_cube_ids=[conv_id + "_speaker_a"],
57+
top_k=top_k,
58+
mode="fine",
59+
internet_search=False,
60+
moscube=False, # cube for mos introduction
61+
session_id=None,
62+
)["text_mem"]
63+
search_a_results = [[m.memory for m in one["memories"]] for one in search_a_results]
64+
search_a_results = [item for sublist in search_a_results for item in sublist]
65+
66+
# Search for speaker B
67+
search_b_results = client.search(
68+
query=query,
69+
user_id=conv_id + "_speaker_b",
70+
install_cube_ids=[conv_id + "_speaker_b"],
71+
top_k=top_k,
72+
mode="fine",
73+
internet_search=False,
74+
moscube=False, # cube for mos introduction
75+
session_id=None,
76+
)["text_mem"]
77+
search_b_results = [[m.memory for m in one["memories"]] for one in search_b_results]
78+
search_b_results = [item for sublist in search_b_results for item in sublist]
79+
80+
speaker_a_context = ""
81+
for item in search_a_results:
82+
speaker_a_context += f"{item}\n"
83+
84+
speaker_b_context = ""
85+
for item in search_b_results:
86+
speaker_b_context += f"{item}\n"
87+
88+
context = SEARCH_PROMPT_MEMOS.format(
89+
speaker_1=speaker_a,
90+
speaker_1_memories=speaker_a_context,
91+
speaker_2=speaker_b,
92+
speaker_2_memories=speaker_b_context,
93+
)
94+
95+
logger.info(f'query "{query[:100]}", context: {context[:100]}"')
96+
duration_ms = (time.time() - start) * 1000
97+
98+
return context, duration_ms
99+
100+
def _process_single_qa_for_time_eval(
101+
self,
102+
qa,
103+
*,
104+
client,
105+
reversed_client,
106+
metadata,
107+
frame,
108+
version,
109+
conv_id,
110+
conv_stats_path,
111+
oai_client,
112+
top_k,
113+
conv_stats,
114+
):
115+
query = qa.get("question")
116+
gold_answer = qa.get("answer")
117+
qa_category = qa.get("category")
118+
if qa_category == 5:
119+
return None
120+
121+
# 1. two parallel process,
122+
# 1. memos search + response
123+
# 2. pre_memories can answer, true : direct answer false:
124+
125+
# Search
126+
assert self.args.frame == MEMOS_SCHEDULER_MODEL
127+
cur_context, search_duration_ms = self.search_query(
128+
client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
129+
)
130+
if not cur_context:
131+
logger.warning(f"No context found for query: {query[:100]}")
132+
cur_context = ""
133+
134+
# Context answer ability analysis (for memos_scheduler only)
135+
if self.pre_context_cache[conv_id] is None:
136+
# Update pre-context cache with current context and return
137+
self.update_context(
138+
conv_id=conv_id,
139+
method=self.context_update_method,
140+
cur_context=cur_context,
141+
)
142+
143+
# ========= MemOS Scheduler update =========
144+
_ = client.mem_scheduler.update_working_memory_for_eval(
145+
query=query, user_id=conv_id + "_speaker_a", top_k=top_k
146+
)
147+
148+
_ = client.mem_scheduler.update_working_memory_for_eval(
149+
query=query, user_id=conv_id + "_speaker_b", top_k=top_k
150+
)
151+
return None
152+
153+
context = self.pre_context_cache[conv_id]
154+
155+
# Generate answer
156+
answer_start = time.time()
157+
answer = self.locomo_response(frame, oai_client, context, query)
158+
response_duration_ms = (time.time() - answer_start) * 1000
159+
160+
can_answer, can_answer_duration_ms = self.eval_context(
161+
context=context, query=query, gold_answer=gold_answer, oai_client=oai_client
162+
)
163+
164+
# Record case for memos_scheduler
165+
try:
166+
recording_case = RecordingCase(
167+
conv_id=conv_id,
168+
query=query,
169+
answer=answer,
170+
context=cur_context,
171+
pre_context=self.pre_context_cache[conv_id],
172+
can_answer=can_answer,
173+
can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
174+
search_duration_ms=search_duration_ms,
175+
can_answer_duration_ms=can_answer_duration_ms,
176+
response_duration_ms=response_duration_ms,
177+
category=int(qa_category) if qa_category is not None else None,
178+
golden_answer=str(qa.get("answer", "")),
179+
)
180+
if can_answer:
181+
self.can_answer_cases.append(recording_case)
182+
else:
183+
self.cannot_answer_cases.append(recording_case)
184+
except Exception as e:
185+
logger.error(f"Error creating RecordingCase: {e}")
186+
print(f"Error creating RecordingCase: {e}")
187+
logger.error(f"QA data: {qa}")
188+
print(f"QA data: {qa}")
189+
logger.error(f"Query: {query}")
190+
logger.error(f"Answer: {answer}")
191+
logger.error(
192+
f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})"
193+
)
194+
logger.error(f"Category: {qa_category} (type: {type(qa_category)})")
195+
logger.error(f"Can answer: {can_answer}")
196+
raise e
197+
198+
# Update conversation stats and context
199+
self._update_stats_and_context(
200+
conv_id=conv_id,
201+
frame=frame,
202+
version=version,
203+
conv_stats=conv_stats,
204+
conv_stats_path=conv_stats_path,
205+
query=query,
206+
answer=answer,
207+
gold_answer=gold_answer,
208+
cur_context=cur_context,
209+
can_answer=can_answer,
210+
)
211+
# ========= MemOS Scheduler update =========
212+
_ = client.mem_scheduler.update_working_memory_for_eval(
213+
query=query, user_id=conv_id + "_speaker_a", top_k=top_k
214+
)
215+
216+
_ = client.mem_scheduler.update_working_memory_for_eval(
217+
query=query, user_id=conv_id + "_speaker_b", top_k=top_k
218+
)
219+
return {
220+
"question": query,
221+
"answer": answer,
222+
"category": qa_category,
223+
"golden_answer": gold_answer,
224+
"search_context": cur_context,
225+
"response_duration_ms": response_duration_ms,
226+
"search_duration_ms": search_duration_ms,
227+
"can_answer_duration_ms": can_answer_duration_ms,
228+
"can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None,
229+
}

evaluation/scripts/temporal_locomo/modules/base_eval_module.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,21 @@ def __init__(self, args):
5959
)
6060
else:
6161
logger.warning(f"Temporal locomo dataset not found at {temporal_locomo_file}")
62+
63+
result_dir_prefix = getattr(self.args, "result_dir_prefix", "")
64+
6265
# Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation
6366
if (
6467
hasattr(self.args, "scheduler_flag")
6568
and self.frame == MEMOS_SCHEDULER_MODEL
6669
and self.args.scheduler_flag is False
6770
):
6871
self.result_dir = Path(
69-
f"{BASE_DIR}/results/temporal_locomo/{self.frame}-{self.version}-ablation/"
72+
f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}-ablation/"
7073
)
7174
else:
7275
self.result_dir = Path(
73-
f"{BASE_DIR}/results/temporal_locomo/{self.frame}-{self.version}/"
76+
f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}/"
7477
)
7578

7679
if self.context_update_method != ContextUpdateMethod.PRE_CONTEXT:
@@ -96,6 +99,10 @@ def __init__(self, args):
9699
if auth_config_path.exists():
97100
auth_config = AuthConfig.from_local_config(config_path=auth_config_path)
98101

102+
self.openai_api_key = auth_config.openai.api_key
103+
self.openai_base_url = auth_config.openai.base_url
104+
self.openai_chat_model = auth_config.openai.default_model
105+
99106
self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8"))
100107
self.mem_cube_config_data = json.load(
101108
self.mem_cube_config_path.open("r", encoding="utf-8")
@@ -126,9 +133,6 @@ def __init__(self, args):
126133
auth_config.graph_db.auto_create
127134
)
128135

129-
self.openai_api_key = auth_config.openai.api_key
130-
self.openai_base_url = auth_config.openai.base_url
131-
self.openai_chat_model = auth_config.openai.default_model
132136
else:
133137
print("Please referring to configs-example to provide valid configs.")
134138
exit()

evaluation/scripts/temporal_locomo/modules/client_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,14 @@ def get_client_from_storage(
146146
scheduler_for_eval.current_mem_cube_id = user_id
147147
scheduler_for_eval.current_user_id = user_id
148148

149+
# set llms to openai api
150+
mos.chat_llm = mos.mem_reader.llm
151+
for cube in mos.mem_cubes.values():
152+
cube.text_mem.dispatcher_llm = mos.mem_reader.llm
153+
cube.text_mem.extractor_llm = mos.mem_reader.llm
154+
149155
# Replace the original scheduler
150156
mos.mem_scheduler = scheduler_for_eval
151-
152157
return mos
153158

154159
def locomo_response(self, frame, llm_client, context: str, question: str) -> str:

0 commit comments

Comments
 (0)