Skip to content

Commit 391b422

Browse files
committed
fix bugs in temporal locomo codes in evaluation
1 parent 020e6c6 commit 391b422

File tree

12 files changed

+608
-1336
lines changed

12 files changed

+608
-1336
lines changed

evaluation/scripts/temporal_locomo/locomo_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ async def limited_task(task):
371371
parser.add_argument(
372372
"--version",
373373
type=str,
374-
default="v0.2.1",
374+
default="v1.0.1",
375375
help="Version identifier for loading results (e.g., 1010)",
376376
)
377377
parser.add_argument(

evaluation/scripts/temporal_locomo/locomo_ingestion.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,22 @@ def ingest_session(self, client, session, frame, metadata, revised_client=None):
3535
date_format = "%I:%M %p on %d %B, %Y UTC"
3636
date_string = datetime.strptime(session_date, date_format).replace(tzinfo=timezone.utc)
3737
iso_date = date_string.isoformat()
38-
conv_idx = metadata["conv_idx"]
39-
conv_id = "locomo_exp_user_" + str(conv_idx)
38+
conv_id = metadata["conv_id"]
39+
conv_id = "locomo_exp_user_" + str(conv_id)
4040
dt = datetime.fromisoformat(iso_date)
4141
timestamp = int(dt.timestamp())
4242
print(f"Processing conv {conv_id}, session {metadata['session_key']}")
4343
start_time = time.time()
44+
print_once = True # Print example only once per session
4445

4546
if frame == ZEP_MODEL:
4647
for chat in tqdm(session, desc=f"{metadata['session_key']}"):
4748
data = chat.get("speaker") + ": " + chat.get("text")
48-
print({"context": data, "conv_id": conv_id, "created_at": iso_date})
49+
50+
# Print example only once per session
51+
if print_once:
52+
print({"context": data, "conv_id": conv_id, "created_at": iso_date})
53+
print_once = False
4954

5055
# Check if the group exists, if not create it
5156
groups = client.group.get_all_groups()
@@ -84,7 +89,10 @@ def ingest_session(self, client, session, frame, metadata, revised_client=None):
8489
f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}"
8590
)
8691

87-
print({"context": data, "conv_id": conv_id, "created_at": iso_date})
92+
# Print example only once per session
93+
if print_once:
94+
print({"context": data, "conv_id": conv_id, "created_at": iso_date})
95+
print_once = False
8896

8997
speaker_a_user_id = conv_id + "_speaker_a"
9098
speaker_b_user_id = conv_id + "_speaker_b"
@@ -119,7 +127,10 @@ def ingest_session(self, client, session, frame, metadata, revised_client=None):
119127
f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}"
120128
)
121129

122-
print({"context": data, "conv_id": conv_id, "created_at": iso_date})
130+
# Print example only once per session
131+
if print_once:
132+
print({"context": data, "conv_id": conv_id, "created_at": iso_date})
133+
print_once = False
123134

124135
for i in range(0, len(messages), 2):
125136
batch_messages = messages[i : i + 2]
@@ -162,40 +173,45 @@ def ingest_session(self, client, session, frame, metadata, revised_client=None):
162173

163174
return elapsed_time
164175

165-
def process_user_for_ingestion(self, conv_idx, frame, locomo_df, version, num_workers=1):
176+
def process_user_for_ingestion(self, conv_id, frame, locomo_df, version, num_workers=1):
166177
try:
167178
# Check if locomo_df is empty or doesn't have the required columns
168179
if locomo_df.empty or "conversation" not in locomo_df.columns:
169180
logger.warning(
170-
f"Skipping user {conv_idx}: locomo_df is empty or missing 'conversation' column"
181+
f"Skipping user {conv_id}: locomo_df is empty or missing 'conversation' column"
171182
)
172183
return 0
173184

174-
conversation = locomo_df["conversation"].iloc[conv_idx]
185+
conversation = locomo_df["conversation"].iloc[conv_id]
175186
max_session_count = 35
176187
start_time = time.time()
177188
total_session_time = 0
178189
valid_sessions = 0
179190

180191
revised_client = None
181192
if frame == "zep":
182-
client = self.get_client_for_ingestion("zep")
193+
client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default")
183194
elif frame == "mem0" or frame == "mem0_graph":
184-
client = self.get_client(frame)
185-
client.delete_all(user_id=f"locomo_exp_user_{conv_idx}")
186-
client.delete_all(user_id=f"{conversation.get('speaker_a')}_{conv_idx}")
187-
client.delete_all(user_id=f"{conversation.get('speaker_b')}_{conv_idx}")
195+
client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default")
196+
client.delete_all(user_id=f"locomo_exp_user_{conv_id}")
197+
client.delete_all(user_id=f"{conversation.get('speaker_a')}_{conv_id}")
198+
client.delete_all(user_id=f"{conversation.get('speaker_b')}_{conv_id}")
188199
elif frame in ["memos", "memos_scheduler"]:
189-
conv_id = "locomo_exp_user_" + str(conv_idx)
200+
conv_id = "locomo_exp_user_" + str(conv_id)
190201
speaker_a_user_id = conv_id + "_speaker_a"
191202
speaker_b_user_id = conv_id + "_speaker_b"
192-
client = self.get_client_for_ingestion(speaker_a_user_id)
193-
revised_client = self.get_client_for_ingestion(speaker_b_user_id)
203+
204+
client = self.get_client_for_ingestion(
205+
frame=frame, user_id=speaker_a_user_id, version=version
206+
)
207+
revised_client = self.get_client_for_ingestion(
208+
frame=frame, user_id=speaker_b_user_id, version=version
209+
)
194210
else:
195211
raise NotImplementedError()
196212

197213
sessions_to_process = []
198-
for session_idx in tqdm(range(max_session_count), desc=f"process_user {conv_idx}"):
214+
for session_idx in tqdm(range(max_session_count), desc=f"process_user {conv_id}"):
199215
session_key = f"session_{session_idx}"
200216
session = conversation.get(session_key)
201217
if session is None:
@@ -205,16 +221,16 @@ def process_user_for_ingestion(self, conv_idx, frame, locomo_df, version, num_wo
205221
"session_date": conversation.get(f"session_{session_idx}_date_time") + " UTC",
206222
"speaker_a": conversation.get("speaker_a"),
207223
"speaker_b": conversation.get("speaker_b"),
208-
"speaker_a_user_id": f"{conversation.get('speaker_a')}_{conv_idx}",
209-
"speaker_b_user_id": f"{conversation.get('speaker_b')}_{conv_idx}",
210-
"conv_idx": conv_idx,
224+
"speaker_a_user_id": f"{conversation.get('speaker_a')}_{conv_id}",
225+
"speaker_b_user_id": f"{conversation.get('speaker_b')}_{conv_id}",
226+
"conv_id": conv_id,
211227
"session_key": session_key,
212228
}
213229
sessions_to_process.append((session, metadata))
214230
valid_sessions += 1
215231

216232
print(
217-
f"Processing {valid_sessions} sessions for user {conv_idx} with {num_workers} workers"
233+
f"Processing {valid_sessions} sessions for user {conv_id} with {num_workers} workers"
218234
)
219235
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
220236
futures = {
@@ -229,18 +245,18 @@ def process_user_for_ingestion(self, conv_idx, frame, locomo_df, version, num_wo
229245
try:
230246
session_time = future.result()
231247
total_session_time += session_time
232-
print(f"User {conv_idx}, {session_key} processed in {session_time} seconds")
248+
print(f"User {conv_id}, {session_key} processed in {session_time} seconds")
233249
except Exception as e:
234-
print(f"Error processing user {conv_idx}, session {session_key}: {e!s}")
250+
print(f"Error processing user {conv_id}, session {session_key}: {e!s}")
235251

236252
end_time = time.time()
237253
elapsed_time = round(end_time - start_time, 2)
238-
print(f"User {conv_idx} processed successfully in {elapsed_time} seconds")
254+
print(f"User {conv_id} processed successfully in {elapsed_time} seconds")
239255

240256
return elapsed_time
241257

242258
except Exception as e:
243-
return f"Error processing user {conv_idx}: {e!s}. Exception: {traceback.format_exc()}"
259+
return f"Error processing user {conv_id}: {e!s}. Exception: {traceback.format_exc()}"
244260

245261
def run_ingestion(self):
246262
frame = self.frame

evaluation/scripts/temporal_locomo/locomo_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
parser.add_argument(
1919
"--version",
2020
type=str,
21-
default="v0.2.1",
21+
default="v1.0.1",
2222
help="Version identifier for loading results (e.g., 1010)",
2323
)
2424

0 commit comments

Comments
 (0)