Skip to content

Commit e790f0b

Browse files
committed
add imap
1 parent e0cd21e commit e790f0b

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

fastchat/serve/monitor/clean_chat_data.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from functools import partial
1313
from math import ceil
1414
from datetime import datetime, timedelta
15+
from tqdm import tqdm
1516
import time
1617
import multiprocessing as mp
1718

@@ -139,8 +140,14 @@ def clean_chat_data(log_files, action_type, num_parallel):
139140
with mp.Pool(num_parallel) as pool:
140141
# Use partial to pass action_type to get_action_type_data
141142
func = partial(get_action_type_data, action_type=action_type)
142-
file_data = pool.map(
143-
func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
143+
file_data = list(
144+
tqdm(
145+
pool.imap(
146+
func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
147+
),
148+
total=len(log_files),
149+
desc="Processing Log Files",
150+
)
144151
)
145152
# filter out Nones as some files may not contain any data belong to action_type
146153
raw_data = []
@@ -151,8 +158,14 @@ def clean_chat_data(log_files, action_type, num_parallel):
151158
# Use the multiprocessing Pool
152159
with mp.Pool(num_parallel) as pool:
153160
func = partial(process_data, action_type=action_type)
154-
results = pool.map(
155-
func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool))
161+
results = list(
162+
tqdm(
163+
pool.imap(
164+
func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool))
165+
),
166+
total=len(raw_data),
167+
desc="Processing Raw Data",
168+
)
156169
)
157170

158171
# Aggregate results from child processes
@@ -161,7 +174,7 @@ def clean_chat_data(log_files, action_type, num_parallel):
161174
ct_network_error = 0
162175
all_models = set()
163176
chats = []
164-
for data in results:
177+
for data in tqdm(results):
165178
if "ct_invalid_conv_id" in data:
166179
ct_invalid_conv_id += data["ct_invalid_conv_id"]
167180
continue

0 commit comments

Comments
 (0)