@@ -1,10 +1,8 @@
 import json
 import sys
-import traceback
 
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from datetime import datetime
 from pathlib import Path
 from time import time
 
@@ -20,10 +18,8 @@
     SEARCH_PROMPT_MEMOS,
     SEARCH_PROMPT_ZEP,
 )
-from modules.schemas import RecordingCase
+from modules.schemas import ContextUpdateMethod, RecordingCase
from modules.utils import save_evaluation_cases
-from openai import OpenAI
-from tqdm import tqdm
 
 from memos.log import get_logger
 
@@ -57,68 +53,24 @@ def __init__(self, args):
 
         self.processed_data_dir = self.result_dir / "processed_data"
 
-        # -------------------------------
-        # Refactor helpers for process_user
-        # -------------------------------
-
-    def _initialize_conv_stats(self):
-        """Create a fresh statistics dictionary for a conversation."""
-        return {
-            "total_queries": 0,
-            "can_answer_count": 0,
-            "cannot_answer_count": 0,
-            "answer_hit_rate": 0.0,
-            "response_failure": 0,
-            "response_count": 0,
-        }
-
-    def _build_day_groups(self, temporal_conv):
-        """Build mapping day_id -> qa_pairs from a temporal conversation dict."""
-        day_groups = {}
-        for day_id, day_data in temporal_conv.get("days", {}).items():
-            day_groups[day_id] = day_data.get("qa_pairs", [])
-        return day_groups
-
-    def _build_metadata(self, speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id):
-        """Assemble metadata for downstream calls."""
-        return {
-            "speaker_a": speaker_a,
-            "speaker_b": speaker_b,
-            "speaker_a_user_id": speaker_a_user_id,
-            "speaker_b_user_id": speaker_b_user_id,
-            "conv_id": conv_id,
-        }
-
-    def _get_clients(self, frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k):
-        """Return (client, reversed_client) according to the target frame."""
-        reversed_client = None
-        if frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
-            client = self.get_client_from_storage(frame, speaker_a_user_id, version, top_k=top_k)
-            reversed_client = self.get_client_from_storage(
-                frame, speaker_b_user_id, version, top_k=top_k
-            )
+    def update_context(self, conv_id, method, **kwargs):
+        if method == ContextUpdateMethod.DIRECT:
+            if "cur_context" not in kwargs:
+                raise ValueError("cur_context is required for DIRECT update method")
+            cur_context = kwargs["cur_context"]
+            self.pre_context_cache[conv_id] = cur_context
+        elif method == ContextUpdateMethod.TEMPLATE:
+            if "query" not in kwargs or "answer" not in kwargs:
+                raise ValueError("query and answer are required for TEMPLATE update method")
+            self._update_context_template(conv_id, kwargs["query"], kwargs["answer"])
         else:
-            client = self.get_client_from_storage(frame, conv_id, version)
-        return client, reversed_client
-
-    def _save_conv_stats(self, conv_id, frame, version, conv_stats, conv_stats_path):
-        """Persist per-conversation stats to disk."""
-        conv_stats_data = {
-            "conversation_id": conv_id,
-            "frame": frame,
-            "version": version,
-            "statistics": conv_stats,
-            "timestamp": str(datetime.now()),
-        }
-        with open(conv_stats_path, "w") as fw:
-            json.dump(conv_stats_data, fw, indent=2, ensure_ascii=False)
-        print(f"Saved conversation stats for {conv_id} to {conv_stats_path}")
+            raise ValueError(f"Unsupported update method: {method}")
 
-    def _write_user_search_results(self, user_search_path, search_results, conv_id):
-        """Write per-user search results to a temporary JSON file."""
-        with open(user_search_path, "w") as fw:
-            json.dump(dict(search_results), fw, indent=2)
-        print(f"Save search results {conv_id}")
+    def _update_context_template(self, conv_id, query, answer):
+        new_context = f"User: {query}\nAssistant: {answer}\n\n"
+        if self.pre_context_cache[conv_id] is None:
+            self.pre_context_cache[conv_id] = ""
+        self.pre_context_cache[conv_id] += new_context
 
     def _process_single_qa(
         self,
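
Aside: `ContextUpdateMethod` is imported from `modules.schemas`, but its definition is not part of this diff. A minimal sketch of the dispatch the new `update_context` implements, assuming a plain two-member enum and a per-conversation cache dict (all names below are illustrative stand-ins, not this file's code):

    # Hypothetical stand-ins; the real ContextUpdateMethod lives in modules.schemas.
    from enum import Enum

    class ContextUpdateMethod(Enum):
        DIRECT = "direct"      # overwrite the cache with the latest retrieved context
        TEMPLATE = "template"  # append a rendered "User/Assistant" turn

    def update_context(cache, conv_id, method, **kwargs):
        if method is ContextUpdateMethod.DIRECT:
            cache[conv_id] = kwargs["cur_context"]
        elif method is ContextUpdateMethod.TEMPLATE:
            turn = f"User: {kwargs['query']}\nAssistant: {kwargs['answer']}\n\n"
            cache[conv_id] = (cache[conv_id] or "") + turn
        else:
            raise ValueError(f"Unsupported update method: {method}")

    cache = {"conv-0": None}
    update_context(cache, "conv-0", ContextUpdateMethod.TEMPLATE,
                   query="Where did Melanie go?", answer="To the lake.")

The `(cache[conv_id] or "")` normalization plays the same role as the `None` check in the real `_update_context_template`.
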
@@ -136,24 +88,35 @@ def _process_single_qa(
         conv_stats,
     ):
         query = qa.get("question")
+        gold_answer = qa.get("answer")
         qa_category = qa.get("category")
         if qa_category == 5:
             return None
 
         # Search
-        context, search_duration_ms = self.search_query(
+        cur_context, search_duration_ms = self.search_query(
             client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
         )
-        if not context:
+        if not cur_context:
             logger.warning(f"No context found for query: {query[:100]}")
-            context = ""
+            cur_context = ""
 
         # Context answerability analysis (for memos_scheduler only)
-        gold_answer = qa.get("answer")
         if self.pre_context_cache[conv_id] is None:
             # Update pre-context cache with current context
-            with self.stats_lock:
-                self.pre_context_cache[conv_id] = context
+            if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+                self.update_context(
+                    conv_id=conv_id,
+                    method=self.context_update_method,
+                    cur_context=cur_context,
+                )
+            else:
+                self.update_context(
+                    conv_id=conv_id,
+                    method=self.context_update_method,
+                    query=query,
+                    answer=gold_answer,
+                )
             return None
 
         can_answer = False
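
Note the control flow: the first question of a conversation only seeds `pre_context_cache` and returns `None`, so it never reaches answer generation (compare the `len(qa_list) - 1` bookkeeping in the removed `process_user` below). A DIRECT-style toy reconstruction of the per-conversation loop this implies, where `retrieve`/`generate` are illustrative stand-ins for `search_query`/`locomo_response`:

    def retrieve(query):           # stand-in for search_query()
        return f"context for {query!r}"

    def generate(context, query):  # stand-in for locomo_response()
        return f"answer to {query!r} from {context!r}"

    qa_pairs = [{"question": "q1"}, {"question": "q2"}, {"question": "q3"}]
    pre_context, results = None, []
    for qa in qa_pairs:
        cur_context = retrieve(qa["question"])
        if pre_context is None:
            pre_context = cur_context   # first turn: seed only, no answer
            continue
        results.append(generate(pre_context, qa["question"]))
        pre_context = cur_context       # refresh after answering
    assert len(results) == len(qa_pairs) - 1
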
@@ -181,15 +144,9 @@ def _process_single_qa(
         )
         self.save_stats()
 
-        # Update pre-context cache with current context
-        with self.stats_lock:
-            self.pre_context_cache[conv_id] = context
-
-        self.print_eval_info()
-
         # Generate answer
         answer_start = time()
-        answer = self.locomo_response(frame, oai_client, context, query)
+        answer = self.locomo_response(frame, oai_client, self.pre_context_cache[conv_id], query)
         response_duration_ms = (time() - answer_start) * 1000
 
         # Record case for memos_scheduler
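
This is the substantive change: the answer is now generated from `self.pre_context_cache[conv_id]` rather than from the context retrieved for the current question, and `cur_context` is only written back to the cache further down, after the case is recorded. With the DIRECT method the model therefore answers turn N from turn N-1's retrieval. A toy trace of that one-turn lag (values hypothetical):

    cache = None
    for n, query in enumerate(["q1", "q2", "q3"], start=1):
        retrieved = f"ctx-{n}"                       # this turn's search result
        if cache is not None:
            print(f"{query} answered from {cache}")  # q2 <- ctx-1, q3 <- ctx-2
        cache = retrieved                            # updated only after answering
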
@@ -199,7 +156,7 @@ def _process_single_qa(
                 conv_id=conv_id,
                 query=query,
                 answer=answer,
-                context=context,
+                context=cur_context,
                 pre_context=self.pre_context_cache[conv_id],
                 can_answer=can_answer,
                 can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
@@ -248,146 +205,37 @@ def _process_single_qa(
 
         logger.info(f"Processed question: {query[:100]}")
         logger.info(f"Answer: {answer[:100]}")
+
+        # Update pre-context cache with current context
+        with self.stats_lock:
+            if self.frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
+                self.update_context(
+                    conv_id=conv_id,
+                    method=self.context_update_method,
+                    cur_context=cur_context,
+                )
+            else:
+                self.update_context(
+                    conv_id=conv_id,
+                    method=self.context_update_method,
+                    query=query,
+                    answer=gold_answer,
+                )
+
+        self.print_eval_info()
+
         return {
             "question": query,
             "answer": answer,
             "category": qa_category,
             "golden_answer": gold_answer,
-            "search_context": context,
+            "search_context": cur_context,
             "response_duration_ms": response_duration_ms,
             "search_duration_ms": search_duration_ms,
             "can_answer_duration_ms": can_answer_duration_ms,
             "can_answer": can_answer if frame == "memos_scheduler" else None,
         }
 
-    def process_user(self, conv_id, locomo_df, frame, version, top_k=20):
-        user_search_path = self.result_dir / f"tmp/{frame}_locomo_search_results_{conv_id}.json"
-        user_search_path.parent.mkdir(exist_ok=True, parents=True)
-        search_results = defaultdict(list)
-        response_results = defaultdict(list)
-        conv_stats_path = self.stats_dir / f"{frame}_{version}_conv_{conv_id}_stats.json"
-
-        conversation = locomo_df["conversation"].iloc[conv_id]
-        speaker_a = conversation.get("speaker_a", "speaker_a")
-        speaker_b = conversation.get("speaker_b", "speaker_b")
-
-        # Use temporal_locomo data if available, otherwise fall back to original locomo data
-        temporal_conv = self.temporal_locomo_data[conv_id]
-        conv_id = temporal_conv["conversation_id"]
-        speaker_a_user_id = f"{conv_id}_speaker_a"
-        speaker_b_user_id = f"{conv_id}_speaker_b"
-
-        # Process temporal data by days
-        day_groups = {}
-        for day_id, day_data in temporal_conv["days"].items():
-            day_groups[day_id] = day_data["qa_pairs"]
-
-        # Initialize conversation-level statistics
-        conv_stats = self._initialize_conv_stats()
-
-        metadata = self._build_metadata(
-            speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id
-        )
-
-        client, reversed_client = self._get_clients(
-            frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k
-        )
-
-        oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url)
-
-        with self.stats_lock:
-            self.pre_context_cache[conv_id] = None
-
-        def process_qa(qa):
-            return self._process_single_qa(
-                qa,
-                client=client,
-                reversed_client=reversed_client,
-                metadata=metadata,
-                frame=frame,
-                version=version,
-                conv_id=conv_id,
-                conv_stats_path=conv_stats_path,
-                oai_client=oai_client,
-                top_k=top_k,
-                conv_stats=conv_stats,
-            )
-
-        # ===================================
-        conv_stats["theoretical_total_queries"] = 0
-        for day, qa_list in day_groups.items():
-            conv_stats["theoretical_total_queries"] += len(qa_list) - 1
-            conv_stats["processing_failure_count"] = 0
-            print(f"Processing user {conv_id} day {day}")
-            for qa in tqdm(qa_list, desc=f"Processing user {conv_id} day {day}"):
-                try:
-                    result = process_qa(qa)
-                except Exception as e:
-                    logger.error(f"Error: {e}. traceback: {traceback.format_exc()}")
-                    conv_stats["processing_failure_count"] += 1
-                    continue
-                if result:
-                    context_preview = (
-                        result["search_context"][:20] + "..."
-                        if result["search_context"]
-                        else "No context"
-                    )
-                    if "can_answer" in result:
-                        logger.info("Print can_answer case")
-                        logger.info(
-                            {
-                                "question": result["question"][:100],
-                                "pre context can answer": result["can_answer"],
-                                "answer": result["answer"][:100],
-                                "golden_answer": result["golden_answer"],
-                                "search_context": context_preview[:100],
-                                "search_duration_ms": result["search_duration_ms"],
-                            }
-                        )
-
-                    search_results[conv_id].append(
-                        {
-                            "question": result["question"],
-                            "context": result["search_context"],
-                            "search_duration_ms": result["search_duration_ms"],
-                        }
-                    )
-                    response_results[conv_id].append(result)
-
-            logger.warning(
-                f"Finished processing user {conv_id} day {day}, data_length: {len(qa_list)}"
-            )
-
-        # recording separate search results
-        with open(user_search_path, "w") as fw:
-            json.dump(dict(search_results), fw, indent=2)
-        print(f"Save search results {conv_id}")
-
-        # Dump stats after processing each user
-        self.save_stats()
-
-        return search_results, response_results
-
-    def process_user_wrapper(self, args):
-        """
-        Wraps the process_user function to support parallel execution and error handling.
-
-        Args:
-            args: Tuple containing parameters for process_user
-
-        Returns:
-            tuple: Contains user results or error information
-        """
-        idx, locomo_df, frame, version, top_k = args
-        try:
-            print(f"Processing user {idx}...")
-            user_search_results, user_response_results = self.process_user(
-                idx, locomo_df, frame, version, top_k
-            )
-            return (user_search_results, user_response_results, None)
-        except Exception as e:
-            return (None, None, (idx, e, traceback.format_exc()))
-
     def run_locomo_processing(self, num_users=10):
         load_dotenv()
 
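
`process_user_wrapper` is removed, yet the `ThreadPoolExecutor`/`as_completed` imports survive, so `run_locomo_processing` presumably still fans users out across a pool with its own error handling. A generic sketch of that pattern (not this file's actual code):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def process_user(idx):                    # stand-in for the real per-user call
        return f"results for user {idx}"

    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = {pool.submit(process_user, i): i for i in range(10)}
        for fut in as_completed(futures):
            idx = futures[fut]
            try:
                print(idx, fut.result())
            except Exception as exc:          # one user's failure shouldn't kill the run
                print(f"user {idx} failed: {exc}")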