Skip to content

Commit 0fda336

Browse files
committed
add breakpoint in eval scripts
1 parent 39a4f29 commit 0fda336

File tree

8 files changed

+292
-138
lines changed

8 files changed

+292
-138
lines changed

evaluation/scripts/PrefEval/pref_mem0.py

Lines changed: 34 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,13 @@
2929

3030

3131
def add_memory_for_line(
32-
line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str
32+
line_data: tuple,
33+
mem_client,
34+
num_irrelevant_turns: int,
35+
lib: str,
36+
version: str,
37+
success_records,
38+
f,
3339
) -> dict:
3440
"""
3541
Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id.
@@ -46,13 +52,22 @@ def add_memory_for_line(
4652
elif num_irrelevant_turns == 300:
4753
conversation = conversation + irre_300
4854

49-
turns_add = 5
5055
start_time_add = time.monotonic()
51-
if conversation:
52-
for chunk_start in range(0, len(conversation), turns_add * 2):
53-
chunk = conversation[chunk_start : chunk_start + turns_add * 2]
54-
timestamp_add = int(time.time() * 100)
55-
mem_client.add(messages=chunk, user_id=user_id, timestamp=timestamp_add)
56+
57+
for idx, _ in enumerate(conversation[::2]):
58+
msg_idx = idx * 2
59+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
60+
timestamp_add = int(time.time() * 100)
61+
62+
if record_id not in success_records:
63+
mem_client.add(
64+
messages=conversation[msg_idx : msg_idx + 2],
65+
user_id=user_id,
66+
timestamp=timestamp_add,
67+
)
68+
f.write(f"{record_id}\n")
69+
f.flush()
70+
5671
end_time_add = time.monotonic()
5772
add_duration = end_time_add - start_time_add
5873

@@ -210,6 +225,15 @@ def main():
210225
from utils.client import Mem0Client
211226

212227
mem_client = Mem0Client(enable_graph="graph" in args.lib)
228+
os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True)
229+
success_records = set()
230+
record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt"
231+
if os.path.exists(record_file):
232+
print(f"Loading existing success records from {record_file}...")
233+
with open(record_file, encoding="utf-8") as f:
234+
for i in f.readlines():
235+
success_records.add(i.strip())
236+
print(f"Loaded {len(success_records)} records.")
213237

214238
if args.mode == "add":
215239
print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...")
@@ -218,6 +242,7 @@ def main():
218242
with (
219243
open(args.output, "w", encoding="utf-8") as outfile,
220244
concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
245+
open(record_file, "a+", encoding="utf-8") as f,
221246
):
222247
futures = [
223248
executor.submit(
@@ -227,6 +252,8 @@ def main():
227252
args.add_turn,
228253
args.lib,
229254
args.version,
255+
success_records,
256+
f,
230257
)
231258
for i, line in enumerate(lines)
232259
]

evaluation/scripts/PrefEval/pref_memobase.py

Lines changed: 28 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
from openai import OpenAI
1313
from tqdm import tqdm
1414

15-
1615
ROOT_DIR = os.path.dirname(
1716
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1817
)
@@ -28,16 +27,20 @@
2827

2928

3029
def add_memory_for_line(
31-
line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str
30+
line_data: tuple,
31+
mem_client,
32+
num_irrelevant_turns: int,
33+
lib: str,
34+
version: str,
35+
success_records,
36+
f,
3237
) -> dict:
3338
"""
3439
Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id.
3540
"""
3641
i, line = line_data
3742
user_id = f"{lib}_user_pref_eval_{i}_{version}"
3843
mem_client.delete_user(user_id)
39-
user_id = mem_client.client.add_user({"user_id": user_id})
40-
print("user_id:", user_id)
4144
try:
4245
original_data = json.loads(line)
4346
conversation = original_data.get("conversation", [])
@@ -63,7 +66,14 @@ def add_memory_for_line(
6366
"created_at": timestamp_add,
6467
}
6568
)
66-
mem_client.add(messages=messages, user_id=user_id)
69+
for idx, _ in enumerate(conversation[::2]):
70+
msg_idx = idx * 2
71+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
72+
73+
if record_id not in success_records:
74+
mem_client.add(messages=conversation[msg_idx : msg_idx + 2], user_id=user_id)
75+
f.write(f"{record_id}\n")
76+
f.flush()
6777

6878
end_time_add = time.monotonic()
6979
add_duration = end_time_add - start_time_add
@@ -222,13 +232,24 @@ def main():
222232

223233
mem_client = MemobaseClient()
224234

235+
os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True)
236+
success_records = set()
237+
record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt"
238+
if os.path.exists(record_file):
239+
print(f"Loading existing success records from {record_file}...")
240+
with open(record_file, encoding="utf-8") as f:
241+
for i in f.readlines():
242+
success_records.add(i.strip())
243+
print(f"Loaded {len(success_records)} records.")
244+
225245
if args.mode == "add":
226246
print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...")
227247
print(f"Adding {args.add_turn} irrelevant turns.")
228248
print(f"Using {args.max_workers} workers.")
229249
with (
230250
open(args.output, "w", encoding="utf-8") as outfile,
231251
concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
252+
open(record_file, "a+", encoding="utf-8") as f,
232253
):
233254
futures = [
234255
executor.submit(
@@ -238,6 +259,8 @@ def main():
238259
args.add_turn,
239260
args.lib,
240261
args.version,
262+
success_records,
263+
f,
241264
)
242265
for i, line in enumerate(lines)
243266
]

evaluation/scripts/PrefEval/pref_memos.py

Lines changed: 38 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
from openai import OpenAI
1313
from tqdm import tqdm
1414

15-
1615
ROOT_DIR = os.path.dirname(
1716
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1817
)
@@ -21,7 +20,6 @@
2120
sys.path.insert(0, ROOT_DIR)
2221
sys.path.insert(0, EVAL_SCRIPTS_DIR)
2322

24-
2523
load_dotenv()
2624
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
2725
BASE_URL = os.getenv("OPENAI_BASE_URL")
@@ -30,8 +28,8 @@
3028

3129

3230
def add_memory_for_line(
33-
line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str
34-
) -> dict:
31+
line_data, mem_client, num_irrelevant_turns, lib, version, success_records, f
32+
):
3533
"""
3634
Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id.
3735
"""
@@ -47,15 +45,22 @@ def add_memory_for_line(
4745
elif num_irrelevant_turns == 300:
4846
conversation = conversation + irre_300
4947

50-
turns_add = 5
5148
start_time_add = time.monotonic()
52-
if conversation:
53-
if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true":
54-
for chunk_start in range(0, len(conversation), turns_add * 2):
55-
chunk = conversation[chunk_start : chunk_start + turns_add * 2]
56-
mem_client.add(messages=chunk, user_id=user_id, conv_id=None, batch_size=2)
57-
else:
58-
mem_client.add(messages=conversation, user_id=user_id, conv_id=None, batch_size=2)
49+
50+
for idx, _ in enumerate(conversation[::2]):
51+
msg_idx = idx * 2
52+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
53+
54+
if record_id not in success_records:
55+
mem_client.add(
56+
messages=conversation[msg_idx : msg_idx + 2],
57+
user_id=user_id,
58+
conv_id=None,
59+
batch_size=2,
60+
)
61+
f.write(f"{record_id}\n")
62+
f.flush()
63+
5964
end_time_add = time.monotonic()
6065
add_duration = end_time_add - start_time_add
6166

@@ -68,7 +73,7 @@ def add_memory_for_line(
6873
return None
6974

7075

71-
def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> dict:
76+
def search_memory_for_line(line_data, mem_client, top_k_value):
7277
"""
7378
Processes a single line of data, searching memory based on the question.
7479
"""
@@ -120,7 +125,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di
120125
return None
121126

122127

123-
def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str) -> dict:
128+
def generate_response_for_line(line_data, openai_client, lib):
124129
"""
125130
Generates a response for a single line of data using pre-fetched memories.
126131
"""
@@ -195,7 +200,7 @@ def main():
195200
parser.add_argument(
196201
"--lib",
197202
type=str,
198-
choices=["memos-api", "memos-local"],
203+
choices=["memos-api", "memos-api-online"],
199204
default="memos-api",
200205
help="Which MemOS library to use (used in 'add' mode).",
201206
)
@@ -218,9 +223,22 @@ def main():
218223
print(f"Error: Input file '{args.input}' not found")
219224
return
220225

221-
from utils.client import MemosApiClient
226+
from utils.client import MemosApiClient, MemosApiOnlineClient
227+
228+
if args.lib == "memos-api":
229+
mem_client = MemosApiClient()
230+
elif args.lib == "memos-api-online":
231+
mem_client = MemosApiOnlineClient()
222232

223-
mem_client = MemosApiClient()
233+
os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True)
234+
success_records = set()
235+
record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt"
236+
if os.path.exists(record_file):
237+
print(f"Loading existing success records from {record_file}...")
238+
with open(record_file, encoding="utf-8") as f:
239+
for i in f.readlines():
240+
success_records.add(i.strip())
241+
print(f"Loaded {len(success_records)} records.")
224242

225243
if args.mode == "add":
226244
print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...")
@@ -229,6 +247,7 @@ def main():
229247
with (
230248
open(args.output, "w", encoding="utf-8") as outfile,
231249
concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
250+
open(record_file, "a+", encoding="utf-8") as record_f,
232251
):
233252
futures = [
234253
executor.submit(
@@ -238,6 +257,8 @@ def main():
238257
args.add_turn,
239258
args.lib,
240259
args.version,
260+
success_records,
261+
record_f,
241262
)
242263
for i, line in enumerate(lines)
243264
]

evaluation/scripts/PrefEval/pref_memu.py

Lines changed: 32 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
from openai import OpenAI
1515
from tqdm import tqdm
1616

17-
1817
ROOT_DIR = os.path.dirname(
1918
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
2019
)
@@ -30,7 +29,13 @@
3029

3130

3231
def add_memory_for_line(
33-
line_data: tuple, mem_client, num_irrelevant_turns: int, lib: str, version: str
32+
line_data: tuple,
33+
mem_client,
34+
num_irrelevant_turns: int,
35+
lib: str,
36+
version: str,
37+
success_records,
38+
f,
3439
) -> dict:
3540
"""
3641
Adds conversation memory for a single line of data to MemOS and returns the data with a persistent user_id.
@@ -47,19 +52,21 @@ def add_memory_for_line(
4752
elif num_irrelevant_turns == 300:
4853
conversation = conversation + irre_300
4954

50-
turns_add = 5
5155
start_time_add = time.monotonic()
52-
if conversation:
53-
if os.getenv("PRE_SPLIT_CHUNK", "false").lower() == "true":
54-
for chunk_start in range(0, len(conversation), turns_add * 2):
55-
chunk = conversation[chunk_start : chunk_start + turns_add * 2]
56-
mem_client.add(
57-
messages=chunk, user_id=user_id, iso_date=datetime.now().isoformat()
58-
)
59-
else:
56+
57+
for idx, _ in enumerate(conversation[::2]):
58+
msg_idx = idx * 2
59+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
60+
61+
if record_id not in success_records:
6062
mem_client.add(
61-
messages=conversation, user_id=user_id, iso_date=datetime.now().isoformat()
63+
messages=conversation[msg_idx : msg_idx + 2],
64+
user_id=user_id,
65+
iso_date=datetime.now().isoformat(),
6266
)
67+
f.write(f"{record_id}\n")
68+
f.flush()
69+
6370
end_time_add = time.monotonic()
6471
add_duration = end_time_add - start_time_add
6572

@@ -219,13 +226,24 @@ def main():
219226

220227
mem_client = MemuClient()
221228

229+
os.makedirs(f"results/prefeval/{args.lib}_{args.version}", exist_ok=True)
230+
success_records = set()
231+
record_file = f"results/prefeval/{args.lib}_{args.version}/success_records.txt"
232+
if os.path.exists(record_file):
233+
print(f"Loading existing success records from {record_file}...")
234+
with open(record_file, encoding="utf-8") as f:
235+
for i in f.readlines():
236+
success_records.add(i.strip())
237+
print(f"Loaded {len(success_records)} records.")
238+
222239
if args.mode == "add":
223240
print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...")
224241
print(f"Adding {args.add_turn} irrelevant turns.")
225242
print(f"Using {args.max_workers} workers.")
226243
with (
227244
open(args.output, "w", encoding="utf-8") as outfile,
228245
concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
246+
open(record_file, "a+", encoding="utf-8") as f,
229247
):
230248
futures = [
231249
executor.submit(
@@ -235,6 +253,8 @@ def main():
235253
args.add_turn,
236254
args.lib,
237255
args.version,
256+
success_records,
257+
f,
238258
)
239259
for i, line in enumerate(lines)
240260
]

0 commit comments

Comments
 (0)