
Commit e024fea

modify scheduler evaluation codes
1 parent cb9519d commit e024fea

File tree

3 files changed: +164, -153 lines


src/memos/mem_scheduler/analyzer/scheduler_for_eval.py

Lines changed: 127 additions & 13 deletions
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+import time
+
+from functools import wraps
+from typing import TYPE_CHECKING, Any, ClassVar
 
 from memos.log import get_logger
 from memos.mem_scheduler.general_scheduler import GeneralScheduler
@@ -24,6 +27,9 @@ class SchedulerForEval(GeneralScheduler):
     This class extends GeneralScheduler with evaluation methods.
     """
 
+    # Class variable to store timing information for all instances
+    timer_cache: ClassVar[dict[str, dict[str, Any]]] = {}
+
     def __init__(self, config):
         """
         Initialize the SchedulerForEval with the same configuration as GeneralScheduler.
@@ -32,7 +38,79 @@ def __init__(self, config):
             config: Configuration object for the scheduler
         """
         super().__init__(config)
+        # Initialize instance timer_cache
+        self.timer_cache = {}
+
+    @staticmethod
+    def time_it(func_name: str | None = None):
+        """
+        Static method decorator to measure function execution time and store in timer_cache.
+
+        Args:
+            func_name: Custom name for the function in timer_cache. If None, uses function.__name__
+        """
+
+        def decorator(func):
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                # Get function name
+                name = func_name or func.__name__
+
+                # Start timing
+                start_time = time.time()
+                result = func(self, *args, **kwargs)
+                end_time = time.time()
+
+                # Calculate execution time
+                exec_time = end_time - start_time
+
+                # Format time as HH:MM:SS.mmm
+                hours = int(exec_time // 3600)
+                minutes = int((exec_time % 3600) // 60)
+                seconds = exec_time % 60
+
+                if hours > 0:
+                    time_str = f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
+                else:
+                    time_str = f"{minutes:02d}:{seconds:06.3f}"
+
+                # Store in timer_cache
+                if not hasattr(self, "timer_cache"):
+                    self.timer_cache = {}
+
+                self.timer_cache[name] = {
+                    "time_str": time_str,
+                    "seconds": exec_time,
+                }
+
+                logger.info(f"{name} executed in {time_str}")
+                return result
+
+            return wrapper
+
+        return decorator
+
+    def get_timer_summary(self) -> str:
+        """
+        Get a summary of all timed functions.
 
+        Returns:
+            Formatted string with timing information
+        """
+        if not self.timer_cache:
+            return "No timing data available."
+
+        summary = "=== Timing Summary ===\n"
+        for func_name, data in self.timer_cache.items():
+            summary += f"{func_name}: {data['time_str']} ({data['seconds']:.3f}s)\n"
+
+        return summary
+
+    def clear_timer_cache(self):
+        """Clear the timer cache."""
+        self.timer_cache.clear()
+
+    @time_it("update_working_memory")
     def update_working_memory_for_eval(
         self, query: str, user_id: UserID | str, top_k: int
     ) -> list[str]:
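For reference, the timing machinery above is a standard decorator-plus-cache pattern: each decorated method writes its elapsed time into the instance's timer_cache (the ClassVar declaration gives a class-level default, while __init__ gives every instance its own dict). Below is a minimal standalone sketch of the same pattern; the Demo class and slow_step method are illustrative only and are not part of the repository.

import time
from functools import wraps


def time_it(func_name=None):
    """Record the wrapped method's wall-clock time in self.timer_cache."""

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            name = func_name or func.__name__
            start = time.time()
            result = func(self, *args, **kwargs)
            elapsed = time.time() - start
            # Keyed by the custom name, mirroring SchedulerForEval.timer_cache
            self.timer_cache[name] = {"seconds": elapsed}
            return result

        return wrapper

    return decorator


class Demo:
    def __init__(self):
        self.timer_cache = {}

    @time_it("slow_step")
    def slow_step(self):
        time.sleep(0.1)


demo = Demo()
demo.slow_step()
print(demo.timer_cache)  # e.g. {'slow_step': {'seconds': 0.10...}}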
@@ -96,12 +174,12 @@ def update_working_memory_for_eval(
                     f"search results for {missing_evidences}: {[one.memory for one in results]}"
                 )
                 new_candidates.extend(results)
-            print(
+            logger.info(
                 f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories."
             )
         else:
             new_candidates = []
-            print(f"intent_result: {intent_result}. not triggered")
+            logger.info(f"intent_result: {intent_result}. not triggered")
 
         # rerank
         new_order_working_memory = self.replace_working_memory(
@@ -116,24 +194,60 @@ def update_working_memory_for_eval(
 
         return [m.memory for m in new_order_working_memory]
 
-    def evaluate_query_with_memories(
-        self, query: str, memory_texts: list[str], user_id: UserID | str
+    @time_it("memory_answer_ability")
+    def evaluate_memory_answer_ability(
+        self, query: str, memory_texts: list[str], top_k: int = 100
     ) -> bool:
         """
         Use LLM to evaluate whether the given memories can answer the query.
 
         Args:
             query: The query string to evaluate
             memory_texts: List of memory texts to check against
-            user_id: User identifier
+            top_k: Maximum number of memories to consider for evaluation
 
         Returns:
             Boolean indicating whether the memories can answer the query
         """
-        queries = [query]
-        intent_result = self.monitor.detect_intent(q_list=queries, text_working_memory=memory_texts)
-        return intent_result["trigger_retrieval"]
+        # Limit the number of memories to evaluate
+        limited_memories = memory_texts[:top_k] if memory_texts else []
+
+        # Build prompt using the template
+        prompt = self.monitor.build_prompt(
+            template_name="memory_answer_ability_evaluation",
+            query=query,
+            memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+            if limited_memories
+            else "No memories available",
+        )
+
+        # Use the process LLM to generate response
+        response = self.monitor._process_llm.generate([{"role": "user", "content": prompt}])
+
+        try:
+            # Extract JSON response
+            from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+
+            result = extract_json_dict(response)
+
+            # Validate response structure
+            if "result" in result:
+                logger.info(
+                    f"Memory answer ability evaluation result: {result['result']}, reason: {result.get('reason', 'No reason provided')}"
+                )
+                return result["result"]
+            else:
+                logger.warning(f"Invalid response structure from LLM: {result}")
+                return False
+
+        except Exception as e:
+            logger.error(
+                f"Failed to parse LLM response for memory answer ability evaluation: {response}. Error: {e}"
+            )
+            # Fallback: return False if we can't determine answer ability
+            return False
 
+    @time_it("search_for_eval")
     def search_for_eval(
         self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True
     ) -> tuple[list[str], bool]:
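The parsing logic above implies a contract with the memory_answer_ability_evaluation prompt: the LLM is expected to reply with a JSON object containing a boolean "result" and an optional "reason", and anything else falls back to False. A small sketch of that contract, using the standard-library json module in place of extract_json_dict and a made-up response string:

import json

# Hypothetical raw LLM reply; the real prompt text comes from the
# "memory_answer_ability_evaluation" template referenced in the diff.
raw_response = '{"result": true, "reason": "The memories contain the requested fact."}'

try:
    parsed = json.loads(raw_response)  # extract_json_dict plays this role in the commit
    can_answer = bool(parsed["result"]) if "result" in parsed else False
except (json.JSONDecodeError, TypeError):
    can_answer = False  # same fallback as the method: undecidable -> False

print(can_answer)  # True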
@@ -157,8 +271,8 @@ def search_for_eval(
             text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
 
             # Use the evaluation function to check if memories can answer the query
-            can_answer = self.evaluate_query_with_memories(
-                query=query, memory_texts=text_working_memory, user_id=user_id
+            can_answer = self.evaluate_memory_answer_ability(
+                query=query, memory_texts=text_working_memory, top_k=top_k
             )
             return text_working_memory, can_answer
         else:
@@ -168,7 +282,7 @@ def search_for_eval(
             )
 
             # Use the evaluation function to check if memories can answer the query
-            can_answer = self.evaluate_query_with_memories(
-                query=query, memory_texts=updated_memories, user_id=user_id
+            can_answer = self.evaluate_memory_answer_ability(
+                query=query, memory_texts=updated_memories, top_k=top_k
             )
             return updated_memories, can_answer
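Putting the pieces together, the evaluation entry point is meant to be driven roughly as follows. This is a sketch only: the run_eval_query helper is hypothetical, the scheduler argument is assumed to be an already-configured SchedulerForEval bound to a mem cube (configuration and cube wiring omitted), and the query and user id are placeholders.

def run_eval_query(scheduler, query: str, user_id: str, top_k: int = 10):
    """Drive one evaluation round; `scheduler` is a configured SchedulerForEval."""
    memories, can_answer = scheduler.search_for_eval(
        query=query,
        user_id=user_id,
        top_k=top_k,
        scheduler_flag=True,  # False skips the working-memory update and reads it as-is
    )
    print(can_answer)                     # whether working memory is judged sufficient
    print(scheduler.get_timer_summary())  # timings recorded by the @time_it decorators
    return memories, can_answer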

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 0 additions & 140 deletions
@@ -37,146 +37,6 @@ def __init__(self, config: GeneralSchedulerConfig):
         }
         self.dispatcher.register_handlers(handlers)
 
-    def update_working_memory_for_eval(
-        self, query: str, user_id: UserID | str, top_k: int
-    ) -> list[str]:
-        """
-        Update working memory based on query and return the updated memory list.
-
-        Args:
-            query: The query string
-            user_id: User identifier
-            top_k: Number of top memories to return
-
-        Returns:
-            List of memory strings from updated working memory
-        """
-        self.monitor.register_query_monitor_if_not_exists(
-            user_id=user_id, mem_cube_id=self.current_mem_cube_id
-        )
-
-        query_keywords = self.monitor.extract_query_keywords(query=query)
-        logger.info(
-            f'Extracted keywords "{query_keywords}" from query "{query}" for user_id={user_id}'
-        )
-
-        item = QueryMonitorItem(
-            user_id=user_id,
-            mem_cube_id=self.current_mem_cube_id,
-            query_text=query,
-            keywords=query_keywords,
-            max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
-        )
-        query_db_manager = self.monitor.query_monitors[user_id][self.current_mem_cube_id]
-        query_db_manager.obj.put(item=item)
-        # Sync with database after adding new item
-        query_db_manager.sync_with_orm()
-        logger.debug(
-            f"Queries in monitor for user_id={user_id}, mem_cube_id={self.current_mem_cube_id}: {query_db_manager.obj.get_queries_with_timesort()}"
-        )
-
-        queries = [query]
-
-        # recall
-        mem_cube = self.current_mem_cube
-        text_mem_base = mem_cube.text_mem
-
-        cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
-        text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
-        intent_result = self.monitor.detect_intent(
-            q_list=queries, text_working_memory=text_working_memory
-        )
-
-        if intent_result["trigger_retrieval"]:
-            missing_evidences = intent_result["missing_evidences"]
-            num_evidence = len(missing_evidences)
-            k_per_evidence = max(1, top_k // max(1, num_evidence))
-            new_candidates = []
-            for item in missing_evidences:
-                logger.info(f"Searching for missing evidence: '{item}' with top_k={k_per_evidence}")
-                results: list[TextualMemoryItem] = self.retriever.search(
-                    query=item,
-                    mem_cube=mem_cube,
-                    top_k=k_per_evidence,
-                    method=self.search_method,
-                )
-                logger.info(
-                    f"Search results for missing evidence '{item}': {[one.memory for one in results]}"
-                )
-                new_candidates.extend(results)
-            print(
-                f"Missing evidences: {missing_evidences} -> Retrieved {len(new_candidates)} new candidate memories for user_id={user_id}"
-            )
-        else:
-            new_candidates = []
-            print(
-                f"Intent detection result: {intent_result} -> Retrieval not triggered for user_id={user_id}"
-            )
-
-        # rerank
-        new_order_working_memory = self.replace_working_memory(
-            user_id=user_id,
-            mem_cube_id=self.current_mem_cube_id,
-            mem_cube=self.current_mem_cube,
-            original_memory=cur_working_memory,
-            new_memory=new_candidates,
-        )
-        new_order_working_memory = new_order_working_memory[:top_k]
-        logger.info(
-            f"Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}"
-        )
-
-        return [m.memory for m in new_order_working_memory]
-
-    def evaluate_query_with_memories(
-        self, query: str, memory_texts: list[str], user_id: UserID | str
-    ) -> bool:
-        """
-        Use LLM to evaluate whether the given memories can answer the query.
-
-        Args:
-            query: The query string to evaluate
-            memory_texts: List of memory texts to check against
-            user_id: User identifier
-
-        Returns:
-            Boolean indicating whether the memories can answer the query
-        """
-        queries = [query]
-        intent_result = self.monitor.detect_intent(q_list=queries, text_working_memory=memory_texts)
-        return intent_result["trigger_retrieval"]
-
-    # for evaluation
-    def search_for_eval(
-        self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True
-    ) -> (list[str], bool):
-        """
-        Original search_for_eval function refactored to use the new decomposed functions.
-        """
-        if not scheduler_flag:
-            # Get current working memory without updating
-            mem_cube = self.current_mem_cube
-            text_mem_base = mem_cube.text_mem
-            cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
-            text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
-
-            # Use the evaluation function to check if memories can answer the query
-            can_answer = self.evaluate_query_with_memories(
-                query=query, memory_texts=text_working_memory, user_id=user_id
-            )
-            return text_working_memory, can_answer
-        else:
-            # Update working memory and get the result
-            updated_memories = self.update_working_memory_for_eval(
-                query=query, user_id=user_id, top_k=top_k
-            )
-
-            # Use the evaluation function to check if memories can answer the query
-            can_answer = self.evaluate_query_with_memories(
-                query=query, memory_texts=updated_memories, user_id=user_id
-            )
-            return updated_memories, can_answer
-
     def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
         """
         Process and handle query trigger messages from the queue.
