Skip to content

Commit 5ac9372

Browse files
Clean Upvote and Downvote data (#3611)
1 parent 185e1a9 commit 5ac9372

File tree

1 file changed

+146
-83
lines changed

1 file changed

+146
-83
lines changed

fastchat/serve/monitor/clean_chat_data.py

Lines changed: 146 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
python3 clean_chat_data.py
66
"""
77
import argparse
8-
import datetime
98
import json
109
import os
10+
import hashlib
1111
from pytz import timezone
12-
import time
13-
12+
from functools import partial
13+
from math import ceil
14+
from datetime import datetime, timedelta
1415
from tqdm import tqdm
16+
import time
17+
import multiprocessing as mp
1518

1619
from fastchat.serve.monitor.basic_stats import NUM_SERVERS
1720
from fastchat.serve.monitor.clean_battle_data import (
@@ -26,12 +29,20 @@
2629
)
2730

2831

29-
def get_log_files(max_num_files=None):
30-
dates = []
31-
for month in range(4, 12):
32-
for day in range(1, 33):
33-
dates.append(f"2023-{month:02d}-{day:02d}")
32+
def date_range(start="2023-04-01"):
    """Return every date from *start* through one day past today.

    Args:
        start: first date to include, formatted ``"YYYY-MM-DD"``.

    Returns:
        List of ``"YYYY-MM-DD"`` strings, consecutive and ascending.
    """
    first_day = datetime.strptime(start, "%Y-%m-%d").date()
    today = datetime.now().date()
    # NOTE(review): `+ 2` makes the range end one day AFTER today —
    # presumably slack for log files stamped ahead of local time; confirm.
    num_days = (today - first_day).days + 2
    # date.isoformat() emits exactly "YYYY-MM-DD", same as strftime("%Y-%m-%d").
    return [(first_day + timedelta(days=offset)).isoformat() for offset in range(num_days)]
42+
43+
44+
def get_log_files(max_num_files=None):
45+
dates = date_range()
3546
filenames = []
3647
for d in dates:
3748
for i in range(NUM_SERVERS):
@@ -44,90 +55,141 @@ def get_log_files(max_num_files=None):
4455
return filenames
4556

4657

47-
def clean_chat_data(log_files, action_type):
58+
def get_action_type_data(filename, action_type):
    """Read a JSONL log file and keep only rows of the requested action type.

    The file may still be syncing from a serving node, so a missing file is
    retried a few times before giving up.

    Args:
        filename: path to a JSON-lines log file (one JSON object per line).
        action_type: value of the ``"type"`` field to keep
            (e.g. ``"chat"``, ``"upvote"``, ``"leftvote"``).

    Returns:
        List of parsed row dicts whose ``"type"`` equals ``action_type``.

    Raises:
        FileNotFoundError: if the file is still missing after all retries.
            (The original code left ``lines`` unbound in that case and died
            with an accidental ``NameError``.)
    """
    for attempt in range(5):
        try:
            # Context manager closes the handle; the original leaked it.
            with open(filename) as log_file:
                lines = log_file.readlines()
            break
        except FileNotFoundError:
            if attempt == 4:
                raise  # retries exhausted — surface the real error
            time.sleep(2)  # give the file sync a moment, then retry

    return [row for row in map(json.loads, lines) if row["type"] == action_type]
72+
73+
74+
def process_data(row, action_type):
    """Validate one raw log row and convert it into a cleaned chat record.

    Args:
        row: parsed JSON log row (shape depends on ``action_type`` —
            ``"state"``/``"model"`` for chat/up/downvote, ``"states"`` pair
            for left/rightvote).
        action_type: one of ``"chat"``, ``"upvote"``, ``"downvote"``,
            ``"leftvote"``, ``"rightvote"``.

    Returns:
        On success, ``{"result": <cleaned record>, "model": <model name>}``.
        On rejection, a single-key counter dict: ``{"ct_invalid_conv_id": 1}``,
        ``{"ct_invalid": 1}``, or ``{"ct_network_error": 1}``.
    """
    # Pull the conversation state and model name; battle votes store a pair
    # of states and the side picks which one. Missing keys mean a malformed row.
    try:
        if action_type in ("chat", "upvote", "downvote"):
            state = row["state"]
            model = row["model"]
        elif action_type == "leftvote":
            state = row["states"][0]
            model = state["model_name"]
        elif action_type == "rightvote":
            state = row["states"][1]
            model = state["model_name"]
        conv_id = state["conv_id"]
    except KeyError:
        return {"ct_invalid_conv_id": 1}

    if conv_id is None:
        return {"ct_invalid_conv_id": 1}

    convo = to_openai_format(state["messages"][state["offset"] :])
    if not isinstance(model, str):
        return {"ct_invalid": 1}
    model = replace_model_name(model, row["tstamp"])

    # Language is detected from the first user message; an empty
    # conversation raises IndexError and is dropped.
    try:
        lang_code = detect_language(state["messages"][state["offset"]][1])
    except IndexError:
        return {"ct_invalid": 1}

    if any(not isinstance(turn["content"], str) for turn in convo):
        return {"ct_invalid": 1}

    full_text = "".join(turn["content"] for turn in convo).lower()
    if NETWORK_ERROR_MSG in full_text:
        return {"ct_network_error": 1}

    # Anonymize the client IP into a stable per-user identifier.
    user_id = hashlib.md5(row["ip"].encode()).hexdigest()

    return {
        "result": {
            "conversation_id": conv_id,
            "model": model,
            "conversation": convo,
            "turn": len(convo) // 2,
            "language": lang_code,
            "user_id": user_id,
            "tstamp": row["tstamp"],
        },
        "model": model,
    }
137+
138+
139+
def clean_chat_data(log_files, action_type, num_parallel):
140+
with mp.Pool(num_parallel) as pool:
141+
# Use partial to pass action_type to get_action_type_data
142+
func = partial(get_action_type_data, action_type=action_type)
143+
file_data = list(
144+
tqdm(
145+
pool.imap(
146+
func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
147+
),
148+
total=len(log_files),
149+
desc="Processing Log Files",
150+
)
151+
)
152+
# filter out Nones as some files may not contain any data belonging to the action_type
48153
raw_data = []
49-
for filename in tqdm(log_files, desc="read files"):
50-
for retry in range(5):
51-
try:
52-
lines = open(filename).readlines()
53-
break
54-
except FileNotFoundError:
55-
time.sleep(2)
56-
57-
for l in lines:
58-
row = json.loads(l)
59-
if row["type"] == action_type:
60-
raw_data.append(row)
154+
for data in file_data:
155+
raw_data.extend(data)
156+
raw_data = [r for r in raw_data if not (r is None)]
157+
158+
# Use the multiprocessing Pool
159+
with mp.Pool(num_parallel) as pool:
160+
func = partial(process_data, action_type=action_type)
161+
results = list(
162+
tqdm(
163+
pool.imap(
164+
func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool))
165+
),
166+
total=len(raw_data),
167+
desc="Processing Raw Data",
168+
)
169+
)
61170

62-
all_models = set()
63-
all_ips = dict()
64-
chats = []
171+
# Aggregate results from child processes
65172
ct_invalid_conv_id = 0
66173
ct_invalid = 0
67174
ct_network_error = 0
68-
for row in raw_data:
69-
try:
70-
if action_type in ["chat", "upvote", "downvote"]:
71-
state = row["state"]
72-
model = row["model"]
73-
elif action_type == "leftvote":
74-
state = row["states"][0]
75-
model = row["states"][0]["model_name"]
76-
elif action_type == "rightvote":
77-
state = row["states"][1]
78-
model = row["states"][1]["model_name"]
79-
conversation_id = state["conv_id"]
80-
except KeyError:
81-
ct_invalid_conv_id += 1
82-
continue
83-
84-
if conversation_id is None:
85-
ct_invalid_conv_id += 1
86-
continue
87-
88-
conversation = to_openai_format(state["messages"][state["offset"] :])
89-
if not isinstance(model, str):
90-
ct_invalid += 1
91-
continue
92-
model = replace_model_name(model, row["tstamp"])
93-
94-
try:
95-
lang_code = detect_language(state["messages"][state["offset"]][1])
96-
except IndexError:
97-
ct_invalid += 1
175+
all_models = set()
176+
chats = []
177+
for data in tqdm(results):
178+
if "ct_invalid_conv_id" in data:
179+
ct_invalid_conv_id += data["ct_invalid_conv_id"]
98180
continue
99-
100-
if not all(isinstance(x["content"], str) for x in conversation):
101-
ct_invalid += 1
181+
if "ct_invalid" in data:
182+
ct_invalid += data["ct_invalid"]
102183
continue
103-
104-
messages = "".join([x["content"] for x in conversation]).lower()
105-
if NETWORK_ERROR_MSG in messages:
106-
ct_network_error += 1
184+
if "ct_network_error" in data:
185+
ct_network_error += data["ct_network_error"]
107186
continue
108-
109-
ip = row["ip"]
110-
if ip not in all_ips:
111-
all_ips[ip] = len(all_ips)
112-
user_id = all_ips[ip]
113-
114-
chats.append(
115-
dict(
116-
conversation_id=conversation_id,
117-
model=model,
118-
conversation=conversation,
119-
turn=len(conversation) // 2,
120-
language=lang_code,
121-
user_id=user_id,
122-
tstamp=row["tstamp"],
123-
)
124-
)
125-
126-
all_models.update([model])
187+
all_models.update([data["model"]])
188+
chats.append(data["result"])
127189

128190
chats.sort(key=lambda x: x["tstamp"])
129191
last_updated_tstamp = chats[-1]["tstamp"]
130-
last_updated_datetime = datetime.datetime.fromtimestamp(
192+
last_updated_datetime = datetime.fromtimestamp(
131193
last_updated_tstamp, tz=timezone("US/Pacific")
132194
).strftime("%Y-%m-%d %H:%M:%S %Z")
133195

@@ -156,12 +218,13 @@ def clean_chat_data(log_files, action_type):
156218
parser = argparse.ArgumentParser()
157219
parser.add_argument("--action-type", type=str, default="chat")
158220
parser.add_argument("--max-num-files", type=int)
221+
parser.add_argument("--num-parallel", type=int, default=16)
159222
args = parser.parse_args()
160223

161224
log_files = get_log_files(args.max_num_files)
162-
chats = clean_chat_data(log_files, args.action_type)
225+
chats = clean_chat_data(log_files, args.action_type, args.num_parallel)
163226
last_updated_tstamp = chats[-1]["tstamp"]
164-
cutoff_date = datetime.datetime.fromtimestamp(
227+
cutoff_date = datetime.fromtimestamp(
165228
last_updated_tstamp, tz=timezone("US/Pacific")
166229
).strftime("%Y%m%d")
167230

0 commit comments

Comments
 (0)