Commit 234e3b0

optimize clean_data script

1 parent 185e1a9

File tree

1 file changed: 123 additions, 76 deletions

fastchat/serve/monitor/clean_chat_data.py

Lines changed: 123 additions & 76 deletions
@@ -5,11 +5,13 @@
 python3 clean_chat_data.py
 """
 import argparse
-import datetime
 import json
 import os
 from pytz import timezone
+from functools import partial
+from datetime import datetime, timedelta
 import time
+import multiprocessing as mp
 
 from tqdm import tqdm
 
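Note: the import swap above is what drives the `datetime.fromtimestamp` edits further down in the diff. Once `datetime` is imported from the datetime module it names the class rather than the module, so the extra `.datetime` prefix is dropped at the call sites. A minimal sketch (the timestamp value is illustrative):

    from datetime import datetime
    from pytz import timezone

    ts = 1700000000  # illustrative Unix timestamp
    # New spelling after the import change; the old datetime.datetime.fromtimestamp(...)
    # would now raise AttributeError, since `datetime` is the class here.
    print(datetime.fromtimestamp(ts, tz=timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z"))
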
@@ -24,14 +26,24 @@
 NETWORK_ERROR_MSG = (
     "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
 )
+MANAGER = mp.Manager()
+LOCK = MANAGER.Lock()
 
 
-def get_log_files(max_num_files=None):
-    dates = []
-    for month in range(4, 12):
-        for day in range(1, 33):
-            dates.append(f"2023-{month:02d}-{day:02d}")
+def date_range(start="2023-04-01"):
+    start_date = datetime.strptime(start, "%Y-%m-%d").date()
+    end_date = datetime.now().date()
+    delta = end_date - start_date
+    dates = [
+        (start_date + timedelta(days=d)).strftime("%Y-%m-%d")
+        for d in range(delta.days + 2)
+    ]
+
+    return dates
+
 
+def get_log_files(max_num_files=None):
+    dates = date_range()
     filenames = []
     for d in dates:
         for i in range(NUM_SERVERS):
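Note: the date_range helper added above replaces the hard-coded 2023 month/day loops with a list of "YYYY-MM-DD" strings running from the start date through one day past today (range(delta.days + 2)), so the log scan stays current without code changes. A small standalone sketch of the same enumeration (the variable names are illustrative):

    from datetime import date, timedelta

    start_date = date(2023, 4, 1)
    delta = date.today() - start_date
    # Every day from the start date through tomorrow, as date_range() produces for get_log_files().
    dates = [(start_date + timedelta(days=d)).strftime("%Y-%m-%d") for d in range(delta.days + 2)]
    print(len(dates), dates[0], dates[-1])
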
@@ -44,90 +56,125 @@ def get_log_files(max_num_files=None):
     return filenames
 
 
-def clean_chat_data(log_files, action_type):
-    raw_data = []
-    for filename in tqdm(log_files, desc="read files"):
-        for retry in range(5):
-            try:
-                lines = open(filename).readlines()
-                break
-            except FileNotFoundError:
-                time.sleep(2)
-
-        for l in lines:
-            row = json.loads(l)
-            if row["type"] == action_type:
-                raw_data.append(row)
+def get_action_type_data(filename, action_type):
+    for _ in range(5):
+        try:
+            lines = open(filename).readlines()
+            break
+        except FileNotFoundError:
+            time.sleep(2)
 
-    all_models = set()
-    all_ips = dict()
-    chats = []
+    for l in lines:
+        row = json.loads(l)
+        if row["type"] == action_type:
+            return row
+
+
+def process_data(row, action_type, all_ips):
+    # Initialize local counters
     ct_invalid_conv_id = 0
     ct_invalid = 0
     ct_network_error = 0
-    for row in raw_data:
-        try:
-            if action_type in ["chat", "upvote", "downvote"]:
-                state = row["state"]
-                model = row["model"]
-            elif action_type == "leftvote":
-                state = row["states"][0]
-                model = row["states"][0]["model_name"]
-            elif action_type == "rightvote":
-                state = row["states"][1]
-                model = row["states"][1]["model_name"]
-            conversation_id = state["conv_id"]
-        except KeyError:
-            ct_invalid_conv_id += 1
-            continue
 
-        if conversation_id is None:
-            ct_invalid_conv_id += 1
-            continue
+    try:
+        if action_type in ["chat", "upvote", "downvote"]:
+            state = row["state"]
+            model = row["model"]
+        elif action_type == "leftvote":
+            state = row["states"][0]
+            model = row["states"][0]["model_name"]
+        elif action_type == "rightvote":
+            state = row["states"][1]
+            model = row["states"][1]["model_name"]
+        conversation_id = state["conv_id"]
+    except KeyError:
+        ct_invalid_conv_id += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    if conversation_id is None:
+        ct_invalid_conv_id += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    conversation = to_openai_format(state["messages"][state["offset"] :])
+    if not isinstance(model, str):
+        ct_invalid += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+    model = replace_model_name(model, row["tstamp"])
+
+    try:
+        lang_code = detect_language(state["messages"][state["offset"]][1])
+    except IndexError:
+        ct_invalid += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    if not all(isinstance(x["content"], str) for x in conversation):
+        ct_invalid += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    messages = "".join([x["content"] for x in conversation]).lower()
+    if NETWORK_ERROR_MSG in messages:
+        ct_network_error += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    ip = row["ip"]
+    # Synchronize access to all_ips using the lock
+    with LOCK:
+        if ip not in all_ips:
+            all_ips[ip] = len(all_ips)
+        user_id = all_ips[ip]
 
-        conversation = to_openai_format(state["messages"][state["offset"] :])
-        if not isinstance(model, str):
-            ct_invalid += 1
-            continue
-        model = replace_model_name(model, row["tstamp"])
+    # Prepare the result data
+    result = dict(
+        conversation_id=conversation_id,
+        model=model,
+        conversation=conversation,
+        turn=len(conversation) // 2,
+        language=lang_code,
+        user_id=user_id,
+        tstamp=row["tstamp"],
+    )
 
-        try:
-            lang_code = detect_language(state["messages"][state["offset"]][1])
-        except IndexError:
-            ct_invalid += 1
-            continue
+    return result, ct_invalid_conv_id, ct_invalid, ct_network_error, model
 
-        if not all(isinstance(x["content"], str) for x in conversation):
-            ct_invalid += 1
-            continue
 
-        messages = "".join([x["content"] for x in conversation]).lower()
-        if NETWORK_ERROR_MSG in messages:
-            ct_network_error += 1
-            continue
+def clean_chat_data(log_files, action_type):
+    with mp.Pool() as pool:
+        # Use partial to pass action_type to get_action_type_data
+        func = partial(get_action_type_data, action_type=action_type)
+        raw_data = pool.map(func, log_files, chunksize=1)
 
-        ip = row["ip"]
-        if ip not in all_ips:
-            all_ips[ip] = len(all_ips)
-        user_id = all_ips[ip]
+    # filter out Nones as some files may not contain any data belong to action_type
+    raw_data = [r for r in raw_data if r is not None]
+    all_ips = MANAGER.dict()
+
+    # Use the multiprocessing Pool
+    with mp.Pool() as pool:
+        func = partial(process_data, action_type=action_type, all_ips=all_ips)
+        results = pool.map(func, raw_data, chunksize=1)
 
-        chats.append(
-            dict(
-                conversation_id=conversation_id,
-                model=model,
-                conversation=conversation,
-                turn=len(conversation) // 2,
-                language=lang_code,
-                user_id=user_id,
-                tstamp=row["tstamp"],
-            )
-        )
+    # Initialize counters and collections in the parent process
+    ct_invalid_conv_id = 0
+    ct_invalid = 0
+    ct_network_error = 0
+    all_models = set()
+    chats = []
 
-        all_models.update([model])
+    # Aggregate results from child processes
+    for res in results:
+        if res is None:
+            continue
+        data, inv_conv_id, inv, net_err, model = res
+        ct_invalid_conv_id += inv_conv_id
+        ct_invalid += inv
+        ct_network_error += net_err
+        if data:
+            chats.append(data)
+        if model:
+            all_models.add(model)
 
     chats.sort(key=lambda x: x["tstamp"])
     last_updated_tstamp = chats[-1]["tstamp"]
-    last_updated_datetime = datetime.datetime.fromtimestamp(
+    last_updated_datetime = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y-%m-%d %H:%M:%S %Z")
 
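Note: the refactor above splits the old single-process loop into get_action_type_data (file reads) and process_data (per-row validation), and fans both out over a multiprocessing.Pool, with functools.partial binding the fixed arguments and a Manager dict guarded by LOCK so every worker sees the same IP-to-user_id mapping. A minimal, self-contained sketch of that sharing pattern (assign_id and the sample keys are illustrative, not from the script; unlike the commit, which defines MANAGER and LOCK at module level, this sketch creates them under the __main__ guard and passes the proxies as arguments so it also runs with spawn-based start methods):

    import multiprocessing as mp
    from functools import partial

    def assign_id(key, mapping, lock):
        # First-come-first-served ids, as process_data() does for client IPs;
        # the lock serializes access to the shared Manager dict.
        with lock:
            if key not in mapping:
                mapping[key] = len(mapping)
            return key, mapping[key]

    if __name__ == "__main__":
        manager = mp.Manager()
        shared_ids = manager.dict()  # visible to all worker processes
        lock = manager.Lock()
        keys = ["1.2.3.4", "5.6.7.8", "1.2.3.4"]  # illustrative inputs
        with mp.Pool() as pool:
            func = partial(assign_id, mapping=shared_ids, lock=lock)
            results = pool.map(func, keys, chunksize=1)
        print(results)  # the repeated key resolves to the same id
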
@@ -161,7 +208,7 @@ def clean_chat_data(log_files, action_type):
     log_files = get_log_files(args.max_num_files)
     chats = clean_chat_data(log_files, args.action_type)
     last_updated_tstamp = chats[-1]["tstamp"]
-    cutoff_date = datetime.datetime.fromtimestamp(
+    cutoff_date = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y%m%d")
 