|
| 1 | +import sys |
| 2 | +import time |
| 3 | + |
| 4 | +from pathlib import Path |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor |
| 8 | +from evaluation.scripts.temporal_locomo.modules.constants import ( |
| 9 | + MEMOS_SCHEDULER_MODEL, |
| 10 | +) |
| 11 | +from evaluation.scripts.temporal_locomo.modules.prompts import ( |
| 12 | + SEARCH_PROMPT_MEMOS, |
| 13 | +) |
| 14 | +from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase |
| 15 | +from memos.log import get_logger |
| 16 | + |
| 17 | + |
# Import MOS only for static type checking; avoids the runtime import cost
# (and any potential import cycle) since MOS is used purely as an annotation.
if TYPE_CHECKING:
    from memos.mem_os.main import MOS

# Locate the evaluation package root three levels above this file and put it
# on sys.path so the script's absolute imports resolve no matter where the
# process was launched from.
FILE_PATH = Path(__file__).absolute()
BASE_DIR = FILE_PATH.parent.parent.parent
sys.path.insert(0, str(BASE_DIR))  # Enable execution from any working directory

# Module-level logger shared by everything in this file.
logger = get_logger(__name__)
| 27 | + |
class LocomoProcessorWithTimeEval(LocomoProcessor):
    """LocomoProcessor variant that measures per-phase QA latency.

    When ``time_eval_mode`` is set on ``args``, the standard per-QA handler is
    swapped for :meth:`_process_single_qa_for_time_eval`, which answers each
    question from the *previous* context snapshot (PRE_CONTEXT strategy) and
    records search / can-answer / response durations in milliseconds.
    """

    def __init__(self, args):
        """Initialize the processor and optionally install the timing handler.

        Args:
            args: Parsed configuration namespace. Must select the MemOS
                scheduler frame; may carry a boolean ``time_eval_mode`` flag
                (defaults to False when absent).
        """
        super().__init__(args=args)
        self.time_eval_mode = getattr(self.args, "time_eval_mode", False)
        # This evaluation only makes sense for the MemOS scheduler frame with
        # pre-context updating; fail fast on misconfiguration.
        # NOTE(review): asserts are stripped under ``python -O``.
        assert self.args.frame == MEMOS_SCHEDULER_MODEL
        assert self.context_update_method == ContextUpdateMethod.PRE_CONTEXT
        if self.time_eval_mode:
            logger.warning(
                "time_eval_mode is activated. _process_single_qa is replaced by _process_single_qa_for_time_eval"
            )
            self._process_single_qa = self._process_single_qa_for_time_eval

    def _search_speaker_memories(self, client, query, user_id, top_k):
        """Run a fine-grained MemOS search for one speaker's cube.

        Args:
            client: MOS client to search with.
            query: Natural-language question text.
            user_id: Speaker-scoped user id (also used as the cube id).
            top_k: Maximum number of memories to retrieve.

        Returns:
            list[str]: Flat list of memory texts from the ``text_mem`` section.
        """
        results = client.search(
            query=query,
            user_id=user_id,
            install_cube_ids=[user_id],
            top_k=top_k,
            mode="fine",
            internet_search=False,
            moscube=False,  # cube for mos introduction
            session_id=None,
        )["text_mem"]
        # Each result entry groups several memory objects; keep only the text.
        return [m.memory for one in results for m in one["memories"]]

    def memos_scheduler_search(
        self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20
    ):
        """MemOS full search process, skipping the scheduler-specific parts.

        Searches both speakers' cubes and formats the results into the
        ``SEARCH_PROMPT_MEMOS`` template.

        Returns:
            tuple[str, float]: The formatted context and the total search
            duration in milliseconds (includes the fallback plain search).
        """
        start = time.time()
        client: MOS = client  # annotation only; not evaluated at runtime (PEP 526)

        if not self.scheduler_flag:
            # Without the scheduler, run a plain search to update working memory.
            self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client)

        # ========= MemOS Search =========
        search_a_results = self._search_speaker_memories(
            client, query, conv_id + "_speaker_a", top_k
        )
        search_b_results = self._search_speaker_memories(
            client, query, conv_id + "_speaker_b", top_k
        )

        # One memory per line; trailing newline kept for prompt formatting.
        speaker_a_context = "".join(f"{item}\n" for item in search_a_results)
        speaker_b_context = "".join(f"{item}\n" for item in search_b_results)

        context = SEARCH_PROMPT_MEMOS.format(
            speaker_1=speaker_a,
            speaker_1_memories=speaker_a_context,
            speaker_2=speaker_b,
            speaker_2_memories=speaker_b_context,
        )

        logger.info(f'query "{query[:100]}", context: "{context[:100]}"')
        duration_ms = (time.time() - start) * 1000

        return context, duration_ms

    def _refresh_working_memory(self, client, query, conv_id, top_k):
        """Trigger the scheduler's eval-time working-memory update for both speakers."""
        for suffix in ("_speaker_a", "_speaker_b"):
            _ = client.mem_scheduler.update_working_memory_for_eval(
                query=query, user_id=conv_id + suffix, top_k=top_k
            )

    def _process_single_qa_for_time_eval(
        self,
        qa,
        *,
        client,
        reversed_client,
        metadata,
        frame,
        version,
        conv_id,
        conv_stats_path,
        oai_client,
        top_k,
        conv_stats,
    ):
        """Process one QA pair while recording per-phase latencies.

        The answer is generated from the *previous* context snapshot
        (``pre_context_cache``); the current search result only becomes the
        context for the next question. Category-5 questions are skipped.

        Returns:
            dict | None: Result record with timing fields, or ``None`` when
            the question is skipped or the pre-context cache was just seeded.
        """
        query = qa.get("question")
        gold_answer = qa.get("answer")
        qa_category = qa.get("category")
        if qa_category == 5:
            # Category 5 is excluded from this evaluation.
            return None

        # Search (MemOS scheduler frame only).
        assert self.args.frame == MEMOS_SCHEDULER_MODEL
        cur_context, search_duration_ms = self.search_query(
            client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
        )
        if not cur_context:
            logger.warning(f"No context found for query: {query[:100]}")
            cur_context = ""

        # First question of a conversation: seed the pre-context cache,
        # refresh working memory, and skip answering this one.
        if self.pre_context_cache[conv_id] is None:
            self.update_context(
                conv_id=conv_id,
                method=self.context_update_method,
                cur_context=cur_context,
            )
            # ========= MemOS Scheduler update =========
            self._refresh_working_memory(client, query, conv_id, top_k)
            return None

        context = self.pre_context_cache[conv_id]

        # Generate the answer from the previous context and time it.
        answer_start = time.time()
        answer = self.locomo_response(frame, oai_client, context, query)
        response_duration_ms = (time.time() - answer_start) * 1000

        # Judge whether the previous context could have answered the question.
        can_answer, can_answer_duration_ms = self.eval_context(
            context=context, query=query, gold_answer=gold_answer, oai_client=oai_client
        )

        # Record the case for later analysis; log details and re-raise on
        # malformed QA data so the run fails loudly.
        try:
            recording_case = RecordingCase(
                conv_id=conv_id,
                query=query,
                answer=answer,
                context=cur_context,
                pre_context=self.pre_context_cache[conv_id],
                can_answer=can_answer,
                can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
                search_duration_ms=search_duration_ms,
                can_answer_duration_ms=can_answer_duration_ms,
                response_duration_ms=response_duration_ms,
                category=int(qa_category) if qa_category is not None else None,
                golden_answer=str(qa.get("answer", "")),
            )
            if can_answer:
                self.can_answer_cases.append(recording_case)
            else:
                self.cannot_answer_cases.append(recording_case)
        except Exception as e:
            logger.error(f"Error creating RecordingCase: {e}")
            print(f"Error creating RecordingCase: {e}")
            logger.error(f"QA data: {qa}")
            print(f"QA data: {qa}")
            logger.error(f"Query: {query}")
            logger.error(f"Answer: {answer}")
            logger.error(
                f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})"
            )
            logger.error(f"Category: {qa_category} (type: {type(qa_category)})")
            logger.error(f"Can answer: {can_answer}")
            raise  # bare raise preserves the original traceback

        # Update conversation stats and roll the context forward.
        self._update_stats_and_context(
            conv_id=conv_id,
            frame=frame,
            version=version,
            conv_stats=conv_stats,
            conv_stats_path=conv_stats_path,
            query=query,
            answer=answer,
            gold_answer=gold_answer,
            cur_context=cur_context,
            can_answer=can_answer,
        )
        # ========= MemOS Scheduler update =========
        self._refresh_working_memory(client, query, conv_id, top_k)
        return {
            "question": query,
            "answer": answer,
            "category": qa_category,
            "golden_answer": gold_answer,
            "search_context": cur_context,
            "response_duration_ms": response_duration_ms,
            "search_duration_ms": search_duration_ms,
            "can_answer_duration_ms": can_answer_duration_ms,
            "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None,
        }
0 commit comments