 python3 clean_chat_data.py
 """
 import argparse
-import datetime
 import json
 import os
+import hashlib
 from pytz import timezone
-import time
-
+from functools import partial
+from math import ceil
+from datetime import datetime, timedelta
 from tqdm import tqdm
+import time
+import multiprocessing as mp
 
 from fastchat.serve.monitor.basic_stats import NUM_SERVERS
 from fastchat.serve.monitor.clean_battle_data import (
 )
 
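The import reshuffle above also changes how `datetime` is referenced: the old `import datetime` required `datetime.datetime.fromtimestamp(...)`, while the new `from datetime import datetime, timedelta` binds the class directly, which is why the timestamp calls later in this diff shorten to `datetime.fromtimestamp(...)`. Both spellings resolve to the same classmethod:

    import datetime as _dt              # old style: module, class reached via _dt.datetime
    from datetime import datetime       # new style: class imported directly

    assert datetime.fromtimestamp is _dt.datetime.fromtimestamp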
 
-def get_log_files(max_num_files=None):
-    dates = []
-    for month in range(4, 12):
-        for day in range(1, 33):
-            dates.append(f"2023-{month:02d}-{day:02d}")
+def date_range(start="2023-04-01"):
+    start_date = datetime.strptime(start, "%Y-%m-%d").date()
+    end_date = datetime.now().date()
+    delta = end_date - start_date
+    dates = [
+        (start_date + timedelta(days=d)).strftime("%Y-%m-%d")
+        for d in range(delta.days + 2)
+    ]
 
+    return dates
+
+
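A quick sketch of what `date_range()` yields (run date hypothetical): `range(delta.days + 2)` enumerates every day from `start` through one day past today, so a log file stamped with tomorrow's date (e.g. written by a server ahead of the local timezone) is still picked up.

    # Hypothetical example, assuming datetime.now() falls on 2023-12-01:
    date_range("2023-11-29")
    # -> ["2023-11-29", "2023-11-30", "2023-12-01", "2023-12-02"]
    #    (delta.days == 2, so range(4) covers start .. today + 1)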
+def get_log_files(max_num_files=None):
+    dates = date_range()
     filenames = []
     for d in dates:
         for i in range(NUM_SERVERS):
@@ -44,90 +55,141 @@ def get_log_files(max_num_files=None):
     return filenames
 
 
-def clean_chat_data(log_files, action_type):
+def get_action_type_data(filename, action_type):
+    for _ in range(5):
+        try:
+            lines = open(filename).readlines()
+            break
+        except FileNotFoundError:
+            time.sleep(2)
+
+    rows = []
+    for l in lines:
+        row = json.loads(l)
+        if row["type"] == action_type:
+            rows.append(row)
+    return rows
+
+
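One caveat in `get_action_type_data`: if the file is still missing after all five retries, `lines` is never assigned and the function raises `NameError` rather than returning an empty list. A defensive variant (a sketch, not part of this patch) would pre-bind `lines`:

    def get_action_type_data(filename, action_type):
        lines = []  # default, so a permanently missing file yields no rows
        for _ in range(5):
            try:
                lines = open(filename).readlines()
                break
            except FileNotFoundError:
                time.sleep(2)
        rows = []
        for l in lines:
            row = json.loads(l)  # one JSON object per log line (JSONL)
            if row["type"] == action_type:
                rows.append(row)
        return rows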
+def process_data(row, action_type):
+    try:
+        if action_type in ["chat", "upvote", "downvote"]:
+            state = row["state"]
+            model = row["model"]
+        elif action_type == "leftvote":
+            state = row["states"][0]
+            model = row["states"][0]["model_name"]
+        elif action_type == "rightvote":
+            state = row["states"][1]
+            model = row["states"][1]["model_name"]
+        conversation_id = state["conv_id"]
+    except KeyError:
+        return {
+            "ct_invalid_conv_id": 1,
+        }
+
+    if conversation_id is None:
+        return {
+            "ct_invalid_conv_id": 1,
+        }
+
+    conversation = to_openai_format(state["messages"][state["offset"] :])
+    if not isinstance(model, str):
+        return {
+            "ct_invalid": 1,
+        }
+    model = replace_model_name(model, row["tstamp"])
+
+    try:
+        lang_code = detect_language(state["messages"][state["offset"]][1])
+    except IndexError:
+        return {
+            "ct_invalid": 1,
+        }
+
+    if not all(isinstance(x["content"], str) for x in conversation):
+        return {
+            "ct_invalid": 1,
+        }
+
+    messages = "".join([x["content"] for x in conversation]).lower()
+    if NETWORK_ERROR_MSG in messages:
+        return {
+            "ct_network_error": 1,
+        }
+    user_id = hashlib.md5(row["ip"].encode()).hexdigest()
+
+    # Prepare the result data
+    result = dict(
+        conversation_id=conversation_id,
+        model=model,
+        conversation=conversation,
+        turn=len(conversation) // 2,
+        language=lang_code,
+        user_id=user_id,
+        tstamp=row["tstamp"],
+    )
+
+    return {
+        "result": result,
+        "model": model,
+    }
+
+
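`process_data` signals its outcome purely through the shape of the returned dict: a single counter key (`ct_invalid_conv_id`, `ct_invalid`, or `ct_network_error`) marks a rejected row, while the `result`/`model` pair marks a kept one; this keeps every return value a plain picklable dict for `multiprocessing`. Note also that `user_id` becomes `md5(ip)` instead of the old sequential `all_ips` index: a hash is stateless, so parallel workers produce identical ids without shared state. A hypothetical consumer of one return value (`out` and `counters` are illustrative names):

    out = process_data(row, "chat")
    if "result" in out:
        chats.append(out["result"])          # cleaned conversation record
        all_models.add(out["model"])
    else:
        (key,) = out.keys()                  # exactly one counter key, valued 1
        counters[key] = counters.get(key, 0) + 1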
+def clean_chat_data(log_files, action_type, num_parallel):
+    with mp.Pool(num_parallel) as pool:
+        # Use partial to pass action_type to get_action_type_data
+        func = partial(get_action_type_data, action_type=action_type)
+        file_data = list(
+            tqdm(
+                pool.imap(
+                    func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
+                ),
+                total=len(log_files),
+                desc="Processing Log Files",
+            )
+        )
+    # Filter out Nones, since some files may not contain any data belonging to action_type
     raw_data = []
-    for filename in tqdm(log_files, desc="read files"):
-        for retry in range(5):
-            try:
-                lines = open(filename).readlines()
-                break
-            except FileNotFoundError:
-                time.sleep(2)
-
-        for l in lines:
-            row = json.loads(l)
-            if row["type"] == action_type:
-                raw_data.append(row)
+    for data in file_data:
+        raw_data.extend(data)
+    raw_data = [r for r in raw_data if r is not None]
+
+    # Use the multiprocessing Pool
+    with mp.Pool(num_parallel) as pool:
+        func = partial(process_data, action_type=action_type)
+        results = list(
+            tqdm(
+                pool.imap(
+                    func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool))
+                ),
+                total=len(raw_data),
+                desc="Processing Raw Data",
+            )
+        )
 
-    all_models = set()
-    all_ips = dict()
-    chats = []
+    # Aggregate results from child processes
     ct_invalid_conv_id = 0
     ct_invalid = 0
     ct_network_error = 0
-    for row in raw_data:
-        try:
-            if action_type in ["chat", "upvote", "downvote"]:
-                state = row["state"]
-                model = row["model"]
-            elif action_type == "leftvote":
-                state = row["states"][0]
-                model = row["states"][0]["model_name"]
-            elif action_type == "rightvote":
-                state = row["states"][1]
-                model = row["states"][1]["model_name"]
-            conversation_id = state["conv_id"]
-        except KeyError:
-            ct_invalid_conv_id += 1
-            continue
-
-        if conversation_id is None:
-            ct_invalid_conv_id += 1
-            continue
-
-        conversation = to_openai_format(state["messages"][state["offset"] :])
-        if not isinstance(model, str):
-            ct_invalid += 1
-            continue
-        model = replace_model_name(model, row["tstamp"])
-
-        try:
-            lang_code = detect_language(state["messages"][state["offset"]][1])
-        except IndexError:
-            ct_invalid += 1
+    all_models = set()
+    chats = []
+    for data in tqdm(results):
+        if "ct_invalid_conv_id" in data:
+            ct_invalid_conv_id += data["ct_invalid_conv_id"]
             continue
-
-        if not all(isinstance(x["content"], str) for x in conversation):
-            ct_invalid += 1
+        if "ct_invalid" in data:
+            ct_invalid += data["ct_invalid"]
             continue
-
-        messages = "".join([x["content"] for x in conversation]).lower()
-        if NETWORK_ERROR_MSG in messages:
-            ct_network_error += 1
+        if "ct_network_error" in data:
+            ct_network_error += data["ct_network_error"]
             continue
-
-        ip = row["ip"]
-        if ip not in all_ips:
-            all_ips[ip] = len(all_ips)
-        user_id = all_ips[ip]
-
-        chats.append(
-            dict(
-                conversation_id=conversation_id,
-                model=model,
-                conversation=conversation,
-                turn=len(conversation) // 2,
-                language=lang_code,
-                user_id=user_id,
-                tstamp=row["tstamp"],
-            )
-        )
-
-        all_models.update([model])
+        all_models.update([data["model"]])
+        chats.append(data["result"])
 
     chats.sort(key=lambda x: x["tstamp"])
     last_updated_tstamp = chats[-1]["tstamp"]
-    last_updated_datetime = datetime.datetime.fromtimestamp(
+    last_updated_datetime = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y-%m-%d %H:%M:%S %Z")
 
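On the `chunksize` passed to both `pool.imap` calls: `pool._pool` is the Pool's private list of worker processes, so `len(pool._pool)` equals `num_parallel`, and `ceil(len(items) / num_parallel)` hands each worker roughly one contiguous slice, trading finer load balancing for less inter-process overhead. The same arithmetic using only public names:

    from math import ceil

    num_parallel = 16                          # pool size, from --num-parallel
    n_items = 1000                             # e.g. len(raw_data)
    chunksize = ceil(n_items / num_parallel)   # 63: ~one slice per worker
    # ceil(len(items) / num_parallel) is equivalent and avoids the private pool._pool.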
@@ -156,12 +218,13 @@ def clean_chat_data(log_files, action_type):
     parser = argparse.ArgumentParser()
     parser.add_argument("--action-type", type=str, default="chat")
     parser.add_argument("--max-num-files", type=int)
+    parser.add_argument("--num-parallel", type=int, default=16)
     args = parser.parse_args()
 
     log_files = get_log_files(args.max_num_files)
-    chats = clean_chat_data(log_files, args.action_type)
+    chats = clean_chat_data(log_files, args.action_type, args.num_parallel)
     last_updated_tstamp = chats[-1]["tstamp"]
-    cutoff_date = datetime.datetime.fromtimestamp(
+    cutoff_date = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y%m%d")
 
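With the new flag wired in, a typical invocation (matching the docstring's usage line; values illustrative) is:

    python3 clean_chat_data.py --action-type chat --max-num-files 100 --num-parallel 16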