Skip to content

Commit cdd9447

Browse files
committed
format code
1 parent 5ed5d56 commit cdd9447

File tree

5 files changed

+34
-61
lines changed

5 files changed

+34
-61
lines changed

evaluation/scripts/locomo/locomo_ingestion.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
1-
import asyncio
21
import os
32
import sys
3+
import argparse
4+
import concurrent.futures
5+
import time
6+
from datetime import datetime, timezone
7+
import pandas as pd
8+
from dotenv import load_dotenv
49

510
ROOT_DIR = os.path.dirname(
611
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
712
)
813
EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts")
914
sys.path.insert(0, ROOT_DIR)
1015
sys.path.insert(0, EVAL_SCRIPTS_DIR)
11-
12-
import argparse
13-
import concurrent.futures
14-
import time
15-
from datetime import datetime, timezone
16-
import pandas as pd
17-
from dotenv import load_dotenv
1816
from prompts import custom_instructions
1917

2018

@@ -64,20 +62,13 @@ def ingest_session(client, session, frame, version, metadata):
6462
client.add(speaker_a_messages, speaker_a_user_id)
6563
client.add(speaker_b_messages, speaker_b_user_id)
6664
elif frame == "memu":
67-
# speaker_a_user_id = metadata['speaker_a']
68-
# speaker_b_user_id = metadata['speaker_b']
69-
# client.agent_id = speaker_b_user_id
7065
client.add(speaker_a_messages, speaker_a_user_id, iso_date)
71-
# client.agent_id = speaker_a_user_id
7266
client.add(speaker_b_messages, speaker_b_user_id, iso_date)
7367
elif frame == "supermemory":
7468
for m in speaker_a_messages:
7569
m["chat_time"] = iso_date
7670
for m in speaker_b_messages:
7771
m["chat_time"] = iso_date
78-
# seems like user_id can not be too long
79-
speaker_a_user_id = f"lcm{conv_idx}a_{version}"
80-
speaker_b_user_id = f"lcm{conv_idx}b_{version}"
8172
client.add(speaker_a_messages, speaker_a_user_id)
8273
client.add(speaker_b_messages, speaker_b_user_id)
8374

@@ -114,11 +105,8 @@ def process_user(conv_idx, frame, locomo_df, version):
114105
client = memobase_client()
115106
all_users = client.client.get_all_users(limit=5000)
116107
for user in all_users:
117-
try:
118-
if user["additional_fields"]["user_id"] in [speaker_a_user_id, speaker_b_user_id]:
119-
client.client.delete_user(user["id"])
120-
except:
121-
pass
108+
if user["additional_fields"]["user_id"] in [speaker_a_user_id, speaker_b_user_id]:
109+
client.client.delete_user(user["id"])
122110
speaker_a_user_id = client.client.add_user({"user_id": speaker_a_user_id})
123111
speaker_b_user_id = client.client.add_user({"user_id": speaker_b_user_id})
124112
elif frame == "memu":
@@ -129,7 +117,6 @@ def process_user(conv_idx, frame, locomo_df, version):
129117
from utils.client import supermemory_client
130118

131119
client = supermemory_client()
132-
133120
sessions_to_process = []
134121
for session_idx in range(max_session_count):
135122
session_key = f"session_{session_idx}"

evaluation/scripts/locomo/locomo_responses.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,6 @@ async def main(frame, version="default"):
112112

113113
os.makedirs("data", exist_ok=True)
114114

115-
# print(all_responses)
116-
117115
with open(response_path, "w") as f:
118116
json.dump(all_responses, f, indent=2)
119117
print("Save response results")

evaluation/scripts/locomo/locomo_search.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import os
22
import sys
3-
3+
import argparse
4+
import json
5+
from collections import defaultdict
6+
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from time import time
8+
import pandas as pd
9+
from dotenv import load_dotenv
10+
from tqdm import tqdm
411

512
ROOT_DIR = os.path.dirname(
613
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -9,16 +16,8 @@
916

1017
sys.path.insert(0, ROOT_DIR)
1118
sys.path.insert(0, EVAL_SCRIPTS_DIR)
12-
from prompts import TEMPLATE_MEM0, TEMPLATE_MEMOBASE, TEMPLATE_MEMOS, TEMPLATE_MEM0_GRAPH
1319

14-
import argparse
15-
import json
16-
from collections import defaultdict
17-
from concurrent.futures import ThreadPoolExecutor, as_completed
18-
from time import time
19-
import pandas as pd
20-
from dotenv import load_dotenv
21-
from tqdm import tqdm
20+
from prompts import TEMPLATE_MEM0, TEMPLATE_MEMOBASE, TEMPLATE_MEMOS, TEMPLATE_MEM0_GRAPH
2221

2322

2423
def mem0_search(client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b):
@@ -168,7 +167,7 @@ def supermemory_search(
168167

169168

170169
def search_query(client, query, metadata, frame, version, top_k=20):
171-
conv_id = metadata.get("conv_id")
170+
_conv_id = metadata.get("conv_id")
172171
speaker_a = metadata.get("speaker_a")
173172
speaker_b = metadata.get("speaker_b")
174173
speaker_a_user_id = metadata.get("speaker_a_user_id")
@@ -247,13 +246,10 @@ def process_user(conv_idx, locomo_df, frame, version, top_k=20, num_workers=1):
247246
client = memobase_client()
248247
users = client.client.get_all_users(limit=5000)
249248
for u in users:
250-
try:
251-
if u["additional_fields"]["user_id"] == speaker_a_user_id:
252-
speaker_a_user_id = u["id"]
253-
if u["additional_fields"]["user_id"] == speaker_b_user_id:
254-
speaker_b_user_id = u["id"]
255-
except:
256-
pass
249+
if u["additional_fields"]["user_id"] == speaker_a_user_id:
250+
speaker_a_user_id = u["id"]
251+
if u["additional_fields"]["user_id"] == speaker_b_user_id:
252+
speaker_b_user_id = u["id"]
257253
elif frame == "memu":
258254
from utils.client import memu_client
259255

evaluation/scripts/longmemeval/lme_ingestion.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import argparse
22
import os
33
import sys
4-
5-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
64
from concurrent.futures import ThreadPoolExecutor, as_completed
75
from datetime import datetime, timezone
8-
96
import pandas as pd
10-
117
from tqdm import tqdm
8+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
129

1310

1411
def ingest_session(session, date, user_id, session_id, frame, client):
@@ -18,7 +15,7 @@ def ingest_session(session, date, user_id, session_id, frame, client):
1815
messages.append({"role": msg["role"], "content": msg["content"][:8000]})
1916
client.add(messages, user_id, int(date.timestamp()))
2017
elif frame == "memobase":
21-
for idx, msg in enumerate(session):
18+
for _idx, msg in enumerate(session):
2219
messages.append(
2320
{
2421
"role": msg["role"],
@@ -39,11 +36,11 @@ def ingest_session(session, date, user_id, session_id, frame, client):
3936
if messages:
4037
client.add(messages=messages, user_id=user_id, conv_id=session_id)
4138
elif frame == "memu":
42-
for idx, msg in enumerate(session):
39+
for _idx, msg in enumerate(session):
4340
messages.append({"role": msg["role"], "content": msg["content"][:8000]})
4441
client.add(messages, user_id, date.isoformat())
4542
elif frame == "supermemory":
46-
for idx, msg in enumerate(session):
43+
for _idx, msg in enumerate(session):
4744
messages.append(
4845
{
4946
"role": msg["role"],
@@ -84,11 +81,8 @@ def ingest_conv(lme_df, version, conv_idx, frame, success_records, f):
8481
client = memobase_client()
8582
all_users = client.client.get_all_users(limit=5000)
8683
for user in all_users:
87-
try:
88-
if user["additional_fields"]["user_id"] == user_id:
89-
client.client.delete_user(user["id"])
90-
except:
91-
pass
84+
if user["additional_fields"]["user_id"] == user_id:
85+
client.client.delete_user(user["id"])
9286
user_id = client.client.add_user({"user_id": user_id})
9387
elif frame == "memu":
9488
from utils.client import memu_client
@@ -135,8 +129,9 @@ def main(frame, version, num_workers=2):
135129
success_records = []
136130
record_file = f"results/lme/{frame}-{version}/success_records.txt"
137131
if os.path.exists(record_file):
138-
for i in open(record_file, "r").readlines():
139-
success_records.append(i.strip())
132+
with open(record_file, "r") as f:
133+
for i in f.readlines():
134+
success_records.append(i.strip())
140135

141136
f = open(record_file, "a+")
142137

evaluation/scripts/longmemeval/lme_search.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,8 @@ def process_user(lme_df, conv_idx, frame, version, top_k=20):
116116
client = memobase_client()
117117
users = client.client.get_all_users(limit=5000)
118118
for u in users:
119-
try:
120-
if u["additional_fields"]["user_id"] == user_id:
121-
user_id = u["id"]
122-
except:
123-
pass
119+
if u["additional_fields"]["user_id"] == user_id:
120+
user_id = u["id"]
124121
context, duration_ms = memobase_search(client, question, user_id, top_k)
125122
elif frame == "memos-api":
126123
from utils.client import memos_api_client
@@ -196,7 +193,7 @@ def main(frame, version, top_k=20, num_workers=2):
196193
for future in tqdm(
197194
as_completed(future_to_idx), total=num_multi_sessions, desc="📊 Processing users"
198195
):
199-
idx = future_to_idx[future]
196+
_idx = future_to_idx[future]
200197
search_results = future.result()
201198
for user_id, results in search_results.items():
202199
all_search_results[user_id].extend(results)

0 commit comments

Comments
 (0)