Skip to content

Commit 5af3515

Browse files
committed
feat(locomo): 支持断点续传 (support resuming ingestion from a checkpoint)
1 parent e1c8dc6 commit 5af3515

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

evaluation/scripts/locomo/locomo_ingestion.py

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def ingest_session(client, session, frame, version, metadata):
8888
return elapsed_time
8989

9090

91-
def process_user(conv_idx, frame, locomo_df, version):
91+
def process_user(conv_idx, frame, locomo_df, version, success_records, f):
9292
conversation = locomo_df["conversation"].iloc[conv_idx]
9393
max_session_count = 35
9494
start_time = time.time()
@@ -149,11 +149,15 @@ def process_user(conv_idx, frame, locomo_df, version):
149149

150150
print(f"Processing {valid_sessions} sessions for user {conv_idx}")
151151

152-
for session, metadata in sessions_to_process:
153-
session_time = ingest_session(client, session, frame, version, metadata)
154-
total_session_time += session_time
155-
print(f"User {conv_idx}, {metadata['session_key']} processed in {session_time} seconds")
156-
152+
for session_idx, (session, metadata) in enumerate(sessions_to_process):
153+
if f"{conv_idx}_{session_idx}" not in success_records:
154+
session_time = ingest_session(client, session, frame, version, metadata)
155+
total_session_time += session_time
156+
print(f"User {conv_idx}, {metadata['session_key']} processed in {session_time} seconds")
157+
f.write(f"{conv_idx}_{session_idx}\n")
158+
f.flush()
159+
else:
160+
print(f"Session {conv_idx}_{session_idx} already ingested")
157161
end_time = time.time()
158162
elapsed_time = round(end_time - start_time, 2)
159163
print(f"User {conv_idx} processed successfully in {elapsed_time} seconds")
@@ -170,9 +174,17 @@ def main(frame, version="default", num_workers=4):
170174
print(
171175
f"Starting processing for {num_users} users in serial mode, each user using {num_workers} workers for sessions..."
172176
)
173-
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
177+
os.makedirs(f"results/locomo/{frame}-{version}/", exist_ok=True)
178+
success_records = []
179+
record_file = f"results/locomo/{frame}-{version}/success_records.txt"
180+
if os.path.exists(record_file):
181+
with open(record_file) as f:
182+
for i in f.readlines():
183+
success_records.append(i.strip())
184+
185+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor, open(record_file, "a+") as f:
174186
futures = [
175-
executor.submit(process_user, user_id, frame, locomo_df, version)
187+
executor.submit(process_user, user_id, frame, locomo_df, version, success_records, f)
176188
for user_id in range(num_users)
177189
]
178190
for future in concurrent.futures.as_completed(futures):
@@ -216,7 +228,7 @@ def main(frame, version="default", num_workers=4):
216228
help="Version identifier for saving results (e.g., 1010)",
217229
)
218230
parser.add_argument(
219-
"--workers", type=int, default=3, help="Number of parallel workers to process users"
231+
"--workers", type=int, default=10, help="Number of parallel workers to process users"
220232
)
221233
args = parser.parse_args()
222234
lib = args.lib

0 commit comments

Comments
 (0)