|
10 | 10 | import hashlib |
11 | 11 | from pytz import timezone |
12 | 12 | from functools import partial |
| 13 | +from math import ceil |
13 | 14 | from datetime import datetime, timedelta |
14 | 15 | import time |
15 | 16 | import multiprocessing as mp |
@@ -134,21 +135,25 @@ def process_data(row, action_type): |
134 | 135 | } |
135 | 136 |
|
136 | 137 |
|
137 | | -def clean_chat_data(log_files, action_type): |
138 | | - with mp.Pool() as pool: |
| 138 | +def clean_chat_data(log_files, action_type, num_parallel): |
| 139 | + with mp.Pool(num_parallel) as pool: |
139 | 140 | # Use partial to pass action_type to get_action_type_data |
140 | 141 | func = partial(get_action_type_data, action_type=action_type) |
141 | | - file_data = pool.map(func, log_files, chunksize=1) |
| 142 | + file_data = pool.map( |
| 143 | +            func, log_files, chunksize=ceil(len(log_files) / num_parallel)
| 144 | + ) |
142 | 145 | # filter out Nones as some files may not contain any data belong to action_type |
143 | 146 | raw_data = [] |
144 | 147 | for data in file_data: |
145 | 148 | raw_data.extend(data) |
146 | 149 | raw_data = [r for r in raw_data if r is not None] |
147 | 150 |
|
148 | 151 | # Use the multiprocessing Pool |
149 | | - with mp.Pool() as pool: |
| 152 | + with mp.Pool(num_parallel) as pool: |
150 | 153 | func = partial(process_data, action_type=action_type) |
151 | | - results = pool.map(func, raw_data, chunksize=1) |
| 154 | + results = pool.map( |
| 155 | +            func, raw_data, chunksize=ceil(len(raw_data) / num_parallel)
| 156 | + ) |
152 | 157 |
|
153 | 158 | # Aggregate results from child processes |
154 | 159 | ct_invalid_conv_id = sum( |
@@ -192,10 +197,11 @@ def clean_chat_data(log_files, action_type): |
192 | 197 | parser = argparse.ArgumentParser() |
193 | 198 | parser.add_argument("--action-type", type=str, default="chat") |
194 | 199 | parser.add_argument("--max-num-files", type=int) |
| 200 | + parser.add_argument("--num-parallel", type=int, default=1) |
195 | 201 | args = parser.parse_args() |
196 | 202 |
|
197 | 203 | log_files = get_log_files(args.max_num_files) |
198 | | - chats = clean_chat_data(log_files, args.action_type) |
| 204 | + chats = clean_chat_data(log_files, args.action_type, args.num_parallel) |
199 | 205 | last_updated_tstamp = chats[-1]["tstamp"] |
200 | 206 | cutoff_date = datetime.fromtimestamp( |
201 | 207 | last_updated_tstamp, tz=timezone("US/Pacific") |
|
0 commit comments