Skip to content

Commit a312b80

Browse files
committed
format:ruff format
1 parent d31a1fb commit a312b80

File tree

4 files changed

+241
-156
lines changed

4 files changed

+241
-156
lines changed

evaluation/scripts/personamem/pm_ingestion.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,36 @@ def ingest_session(session, user_id, session_id, frame, client):
2020
pass
2121
for idx, msg in enumerate(session):
2222
print(
23-
f"[{frame}] 💬 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}...")
24-
client.memory.add(messages=[Message(role=msg["role"], role_type=msg["role"], content=msg["content"], )], )
23+
f"[{frame}] 💬 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}..."
24+
)
25+
client.memory.add(
26+
messages=[
27+
Message(
28+
role=msg["role"],
29+
role_type=msg["role"],
30+
content=msg["content"],
31+
)
32+
],
33+
)
2534
elif frame == "mem0-local" or frame == "mem0-api":
2635
for idx, msg in enumerate(session):
2736
messages.append({"role": msg["role"], "content": msg["content"]})
2837
print(
29-
f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}...")
38+
f"[{frame}] 📝 Session [{session_id}: [{idx + 1}/{len(session)}] Ingesting message: {msg['role']} - {msg['content'][:50]}..."
39+
)
3040
if frame == "mem0-local":
3141
client.add(messages=messages, user_id=user_id)
3242
elif frame == "mem0-api":
33-
client.add(messages=messages,
34-
user_id=user_id,
35-
session_id=session_id,
36-
version="v2", )
43+
client.add(
44+
messages=messages,
45+
user_id=user_id,
46+
session_id=session_id,
47+
version="v2",
48+
)
3749
print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages")
3850
elif frame == "memos-local" or frame == "memos-api":
3951
for i in range(0, len(session), 10):
40-
messages = session[i: i + 10]
52+
messages = session[i : i + 10]
4153
client.add(messages=messages, user_id=user_id, conv_id=session_id)
4254
print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(messages)} messages")
4355

@@ -48,7 +60,7 @@ def build_jsonl_index(jsonl_path):
4860
Assumes each line is a JSON object with a single key-value pair.
4961
"""
5062
index = {}
51-
with open(jsonl_path, 'r', encoding='utf-8') as f:
63+
with open(jsonl_path, "r", encoding="utf-8") as f:
5264
while True:
5365
offset = f.tell()
5466
line = f.readline()
@@ -60,14 +72,14 @@ def build_jsonl_index(jsonl_path):
6072

6173

6274
def load_context_by_id(jsonl_path, offset):
63-
with open(jsonl_path, 'r', encoding='utf-8') as f:
75+
with open(jsonl_path, "r", encoding="utf-8") as f:
6476
f.seek(offset)
6577
item = json.loads(f.readline())
6678
return next(iter(item.values()))
6779

6880

6981
def load_rows(csv_path):
70-
with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile:
82+
with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile:
7183
reader = csv.DictReader(csvfile)
7284
for _, row in enumerate(reader, start=1):
7385
row_data = {}
@@ -79,7 +91,7 @@ def load_rows(csv_path):
7991
def load_rows_with_context(csv_path, jsonl_path):
8092
jsonl_index = build_jsonl_index(jsonl_path)
8193

82-
with open(csv_path, mode='r', newline='', encoding='utf-8') as csvfile:
94+
with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile:
8395
reader = csv.DictReader(csvfile)
8496
prev_sid = None
8597
prev_context = None
@@ -99,13 +111,13 @@ def load_rows_with_context(csv_path, jsonl_path):
99111

100112

101113
def count_csv_rows(csv_path):
102-
with open(csv_path, mode='r', newline='', encoding='utf-8') as f:
114+
with open(csv_path, mode="r", newline="", encoding="utf-8") as f:
103115
return sum(1 for _ in f) - 1
104116

105117

106118
def ingest_conv(row_data, context, version, conv_idx, frame):
107119
end_index_in_shared_context = row_data["end_index_in_shared_context"]
108-
context = context[:int(end_index_in_shared_context)]
120+
context = context[: int(end_index_in_shared_context)]
109121
user_id = f"pm_exper_user_{conv_idx}_{version}"
110122
print(f"👤 User ID: {user_id}")
111123
print("\n" + "=" * 80)
@@ -157,7 +169,13 @@ def ingest_conv(row_data, context, version, conv_idx, frame):
157169
print(f"📊 Total sessions to ingest: {len(sessions)}")
158170

159171
for idx, session in enumerate(sessions):
160-
ingest_session(session=session, user_id=user_id, session_id=idx, frame=frame, client=client, )
172+
ingest_session(
173+
session=session,
174+
user_id=user_id,
175+
session_id=idx,
176+
frame=frame,
177+
client=client,
178+
)
161179
print(f"✅ Ingestion of conversation {conv_idx} completed")
162180
print("=" * 80)
163181

@@ -180,16 +198,25 @@ def main(frame, version, num_workers=2):
180198

181199
with ThreadPoolExecutor(max_workers=num_workers) as executor:
182200
future_to_idx = {
183-
executor.submit(ingest_conv, row_data=row_data, context=context, version=version, conv_idx=idx,
184-
frame=frame, ): idx
185-
for idx, (row_data, context) in enumerate(all_data)}
186-
187-
for future in tqdm(as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations"):
201+
executor.submit(
202+
ingest_conv,
203+
row_data=row_data,
204+
context=context,
205+
version=version,
206+
conv_idx=idx,
207+
frame=frame,
208+
): idx
209+
for idx, (row_data, context) in enumerate(all_data)
210+
}
211+
212+
for future in tqdm(
213+
as_completed(future_to_idx), total=len(future_to_idx), desc="Processing conversations"
214+
):
188215
idx = future_to_idx[future]
189216
try:
190217
future.result()
191218
except Exception as exc:
192-
print(f'\n❌ Conversation {idx} generated an exception: {exc}')
219+
print(f"\n❌ Conversation {idx} generated an exception: {exc}")
193220

194221
end_time = datetime.now()
195222
elapsed_time = end_time - start_time
@@ -205,10 +232,18 @@ def main(frame, version, num_workers=2):
205232

206233
if __name__ == "__main__":
207234
parser = argparse.ArgumentParser(description="PersonaMem Ingestion Script")
208-
parser.add_argument("--lib", type=str, choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"],
209-
default='memos-api')
210-
parser.add_argument("--version", type=str, default="0925-1", help="Version of the evaluation framework.")
211-
parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.")
235+
parser.add_argument(
236+
"--lib",
237+
type=str,
238+
choices=["mem0-local", "mem0-api", "memos-local", "memos-api", "zep"],
239+
default="memos-api",
240+
)
241+
parser.add_argument(
242+
"--version", type=str, default="0925-1", help="Version of the evaluation framework."
243+
)
244+
parser.add_argument(
245+
"--workers", type=int, default=3, help="Number of parallel workers for processing users."
246+
)
212247
args = parser.parse_args()
213248

214249
main(frame=args.lib, version=args.version, num_workers=args.workers)

0 commit comments

Comments
 (0)