Skip to content

Commit 7ced22d

Browse files
committed
fix chunck size
1 parent 78941e8 commit 7ced22d

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

fastchat/serve/monitor/clean_chat_data.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import hashlib
1111
from pytz import timezone
1212
from functools import partial
13+
from math import ceil
1314
from datetime import datetime, timedelta
1415
import time
1516
import multiprocessing as mp
@@ -134,21 +135,25 @@ def process_data(row, action_type):
134135
}
135136

136137

137-
def clean_chat_data(log_files, action_type):
138-
with mp.Pool() as pool:
138+
def clean_chat_data(log_files, action_type, num_parallel):
139+
with mp.Pool(num_parallel) as pool:
139140
# Use partial to pass action_type to get_action_type_data
140141
func = partial(get_action_type_data, action_type=action_type)
141-
file_data = pool.map(func, log_files, chunksize=1)
142+
file_data = pool.map(
143+
func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
144+
)
142145
# filter out Nones as some files may not contain any data belong to action_type
143146
raw_data = []
144147
for data in file_data:
145148
raw_data.extend(data)
146149
raw_data = [r for r in raw_data if r is not None]
147150

148151
# Use the multiprocessing Pool
149-
with mp.Pool() as pool:
152+
with mp.Pool(num_parallel) as pool:
150153
func = partial(process_data, action_type=action_type)
151-
results = pool.map(func, raw_data, chunksize=1)
154+
results = pool.map(
155+
func, raw_data, chunksize=ceil(len(log_files) / len(pool._pool))
156+
)
152157

153158
# Aggregate results from child processes
154159
ct_invalid_conv_id = sum(
@@ -192,10 +197,11 @@ def clean_chat_data(log_files, action_type):
192197
parser = argparse.ArgumentParser()
193198
parser.add_argument("--action-type", type=str, default="chat")
194199
parser.add_argument("--max-num-files", type=int)
200+
parser.add_argument("--num-parallel", type=int, default=1)
195201
args = parser.parse_args()
196202

197203
log_files = get_log_files(args.max_num_files)
198-
chats = clean_chat_data(log_files, args.action_type)
204+
chats = clean_chat_data(log_files, args.action_type, args.num_parallel)
199205
last_updated_tstamp = chats[-1]["tstamp"]
200206
cutoff_date = datetime.fromtimestamp(
201207
last_updated_tstamp, tz=timezone("US/Pacific")

0 commit comments

Comments
 (0)