11from __future__ import annotations
22
3- from typing import TYPE_CHECKING
3+ import time
4+
5+ from functools import wraps
6+ from typing import TYPE_CHECKING , Any , ClassVar
47
58from memos .log import get_logger
69from memos .mem_scheduler .general_scheduler import GeneralScheduler
@@ -24,6 +27,9 @@ class SchedulerForEval(GeneralScheduler):
2427 This class extends GeneralScheduler with evaluation methods.
2528 """
2629
30+ # Class-level default cache; __init__ rebinds timer_cache per instance, so timings are not shared across instances
31+ timer_cache : ClassVar [dict [str , dict [str , Any ]]] = {}
32+
2733 def __init__ (self , config ):
2834 """
2935 Initialize the SchedulerForEval with the same configuration as GeneralScheduler.
@@ -32,7 +38,79 @@ def __init__(self, config):
3238 config: Configuration object for the scheduler
3339 """
3440 super ().__init__ (config )
41+ # Initialize instance timer_cache
42+ self .timer_cache = {}
43+
@staticmethod
def time_it(func_name: str | None = None):
    """
    Build a decorator that times a method call and records the result on the instance.

    After each successful call of the decorated method, an entry is written to
    ``self.timer_cache`` under ``func_name`` (or the wrapped function's
    ``__name__``) with these keys:

    - ``"time_str"``: elapsed time as ``MM:SS.mmm`` (``HH:MM:SS.mmm`` from one hour up)
    - ``"seconds"``: elapsed time in seconds as a float
    - ``"timestamp"``: local wall-clock completion time, ``YYYY-MM-DD HH:MM:SS``

    A later call under the same name overwrites the earlier entry. Nothing is
    recorded when the wrapped method raises.

    Args:
        func_name: Custom name for the function in timer_cache. If None, uses function.__name__

    Returns:
        A decorator for methods whose first positional argument is ``self``.
    """
    # NOTE(review): using ``@time_it(...)`` inside the class body calls the
    # staticmethod object directly, which is only supported on Python 3.10+.

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            name = func_name or func.__name__

            # perf_counter is monotonic and high-resolution; time.time() can jump
            # when the system clock is adjusted and corrupt the measured interval.
            start = time.perf_counter()
            result = func(self, *args, **kwargs)
            exec_time = time.perf_counter() - start

            # Split on whole milliseconds so rounding can never print "60.000" seconds.
            total_ms = round(exec_time * 1000)
            total_seconds, millis = divmod(total_ms, 1000)
            total_minutes, secs = divmod(total_seconds, 60)
            hours, minutes = divmod(total_minutes, 60)

            if hours > 0:
                time_str = f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"
            else:
                time_str = f"{minutes:02d}:{secs:02d}.{millis:03d}"

            # Objects that skipped __init__ may not have a cache yet.
            if not hasattr(self, "timer_cache"):
                self.timer_cache = {}

            self.timer_cache[name] = {
                "time_str": time_str,
                "seconds": exec_time,
                # Completion time, so summaries can report when the call finished.
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            }

            logger.info(f"{name} executed in {time_str}")
            return result

        return wrapper

    return decorator
92+
def get_timer_summary(self) -> str:
    """
    Get a summary of all timed functions.

    Each entry of ``self.timer_cache`` is rendered on its own line as
    ``<name>: <time_str>``. The suffix `` (at <timestamp>)`` is appended only
    when the entry carries a ``"timestamp"`` key, so entries that hold just
    ``"time_str"`` and ``"seconds"`` no longer raise ``KeyError``.

    Returns:
        Formatted string with timing information, or a fixed placeholder
        message when the cache is empty.
    """
    if not self.timer_cache:
        return "No timing data available."

    lines = ["=== Timing Summary ==="]
    for func_name, data in self.timer_cache.items():
        entry = f"{func_name}: {data['time_str']}"
        timestamp = data.get("timestamp")
        if timestamp is not None:
            entry += f" (at {timestamp})"
        lines.append(entry)

    # Every line, including the last one, is newline-terminated.
    return "\n".join(lines) + "\n"
108+
def clear_timer_cache(self):
    """Empty ``self.timer_cache`` in place, discarding every recorded timing entry."""
    cache = self.timer_cache
    cache.clear()
112+
113+ @time_it ("update_working_memory" )
36114 def update_working_memory_for_eval (
37115 self , query : str , user_id : UserID | str , top_k : int
38116 ) -> list [str ]:
@@ -96,12 +174,12 @@ def update_working_memory_for_eval(
96174 f"search results for { missing_evidences } : { [one .memory for one in results ]} "
97175 )
98176 new_candidates .extend (results )
99- print (
177+ logger . info (
100178 f"missing_evidences: { missing_evidences } and get { len (new_candidates )} new candidate memories."
101179 )
102180 else :
103181 new_candidates = []
104- print (f"intent_result: { intent_result } . not triggered" )
182+ logger . info (f"intent_result: { intent_result } . not triggered" )
105183
106184 # rerank
107185 new_order_working_memory = self .replace_working_memory (
@@ -116,24 +194,60 @@ def update_working_memory_for_eval(
116194
117195 return [m .memory for m in new_order_working_memory ]
118196
@time_it("memory_answer_ability")
def evaluate_memory_answer_ability(
    self, query: str, memory_texts: list[str], top_k: int = 100
) -> bool:
    """
    Use LLM to evaluate whether the given memories can answer the query.

    A prompt is built from the ``memory_answer_ability_evaluation`` template,
    sent to the monitor's process LLM, and the reply is parsed as a JSON dict
    whose ``"result"`` field carries the verdict. Any parsing failure or
    unexpected reply shape is logged and treated as "cannot answer".

    Args:
        query: The query string to evaluate
        memory_texts: List of memory texts to check against
        top_k: Maximum number of memories to consider for evaluation

    Returns:
        Boolean indicating whether the memories can answer the query
    """
    # Limit the number of memories to evaluate
    limited_memories = memory_texts[:top_k] if memory_texts else []

    if limited_memories:
        memory_list = "\n".join(f"- {memory}" for memory in limited_memories)
    else:
        memory_list = "No memories available"

    # Build prompt using the template
    prompt = self.monitor.build_prompt(
        template_name="memory_answer_ability_evaluation",
        query=query,
        memory_list=memory_list,
    )

    # Use the process LLM to generate response
    response = self.monitor._process_llm.generate([{"role": "user", "content": prompt}])

    try:
        # Imported here, as in the original; presumably to avoid an import
        # cycle at module load — TODO confirm.
        from memos.mem_scheduler.utils.misc_utils import extract_json_dict

        result = extract_json_dict(response)

        # Validate response structure
        if not isinstance(result, dict) or "result" not in result:
            logger.warning(f"Invalid response structure from LLM: {result}")
            return False

        verdict = result["result"]
        logger.info(
            f"Memory answer ability evaluation result: {verdict}, reason: {result.get('reason', 'No reason provided')}"
        )

        # The LLM may emit the verdict as a string; a bare "false" would be
        # truthy, so normalise textual answers before honouring the bool contract.
        if isinstance(verdict, str):
            return verdict.strip().lower() in {"true", "yes", "1"}
        return bool(verdict)

    except Exception as e:
        # Deliberate best-effort boundary: evaluation must not crash the caller.
        logger.error(
            f"Failed to parse LLM response for memory answer ability evaluation: {response}. Error: {e}"
        )
        # Fallback: return False if we can't determine answer ability
        return False
136249
250+ @time_it ("search_for_eval" )
137251 def search_for_eval (
138252 self , query : str , user_id : UserID | str , top_k : int , scheduler_flag : bool = True
139253 ) -> tuple [list [str ], bool ]:
@@ -157,8 +271,8 @@ def search_for_eval(
157271 text_working_memory : list [str ] = [w_m .memory for w_m in cur_working_memory ]
158272
159273 # Use the evaluation function to check if memories can answer the query
160- can_answer = self .evaluate_query_with_memories (
161- query = query , memory_texts = text_working_memory , user_id = user_id
274+ can_answer = self .evaluate_memory_answer_ability (
275+ query = query , memory_texts = text_working_memory , top_k = top_k
162276 )
163277 return text_working_memory , can_answer
164278 else :
@@ -168,7 +282,7 @@ def search_for_eval(
168282 )
169283
170284 # Use the evaluation function to check if memories can answer the query
171- can_answer = self .evaluate_query_with_memories (
172- query = query , memory_texts = updated_memories , user_id = user_id
285+ can_answer = self .evaluate_memory_answer_ability (
286+ query = query , memory_texts = updated_memories , top_k = top_k
173287 )
174288 return updated_memories , can_answer
0 commit comments