55python3 clean_chat_data.py
66"""
77import argparse
8- import datetime
98import json
109import os
1110from pytz import timezone
11+ from functools import partial
12+ from datetime import datetime , timedelta
1213import time
14+ import multiprocessing as mp
1315
1416from tqdm import tqdm
1517
NETWORK_ERROR_MSG = (
    "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
)

# Manager-backed lock shared with worker processes; it serializes the
# check-then-insert on the shared `all_ips` mapping in process_data.
# NOTE(review): mp.Manager() at module level spawns a helper process at
# import time — confirm this module is only imported by the main process
# (fine under the fork start method; hazardous under spawn).
MANAGER = mp.Manager()
LOCK = MANAGER.Lock()
2731
2832
def date_range(start="2023-04-01"):
    """Return every date from `start` through tomorrow as "YYYY-MM-DD" strings.

    The range deliberately runs one day past today (`+ 2` below), which
    covers log servers whose local date is already ahead of this machine's.
    """
    first = datetime.strptime(start, "%Y-%m-%d").date()
    span = (datetime.now().date() - first).days
    return [
        (first + timedelta(days=offset)).strftime("%Y-%m-%d")
        for offset in range(span + 2)
    ]
3444
45+ def get_log_files (max_num_files = None ):
46+ dates = date_range ()
3547 filenames = []
3648 for d in dates :
3749 for i in range (NUM_SERVERS ):
@@ -44,90 +56,125 @@ def get_log_files(max_num_files=None):
4456 return filenames
4557
4658
def get_action_type_data(filename, action_type):
    """Return the first logged row of type `action_type` from `filename`.

    Retries up to 5 times, 2 seconds apart, when the file does not exist
    yet (a server may not have flushed its log). Returns None when the
    file never appears or contains no row of the requested type; callers
    filter out None results.
    """
    lines = None
    for _ in range(5):
        try:
            # `with` closes the handle promptly instead of leaking it.
            with open(filename) as f:
                lines = f.readlines()
            break
        except FileNotFoundError:
            time.sleep(2)

    if lines is None:
        # File never showed up; without this guard the loop below would
        # raise NameError on the unbound `lines`.
        return None

    for l in lines:
        row = json.loads(l)
        if row["type"] == action_type:
            return row
    return None
def process_data(row, action_type, all_ips):
    """Clean one raw log row into a chat record.

    Returns a 5-tuple (record, ct_invalid_conv_id, ct_invalid,
    ct_network_error, model). `record` and `model` are None when the row
    is rejected, with exactly one counter set to 1; on success the record
    and model are returned with all counters 0.
    """
    try:
        # Select the conversation state and model for this action type.
        if action_type in ("chat", "upvote", "downvote"):
            state = row["state"]
            model = row["model"]
        elif action_type == "leftvote":
            state = row["states"][0]
            model = row["states"][0]["model_name"]
        elif action_type == "rightvote":
            state = row["states"][1]
            model = row["states"][1]["model_name"]
        conversation_id = state["conv_id"]
    except KeyError:
        return None, 1, 0, 0, None

    if conversation_id is None:
        return None, 1, 0, 0, None

    conversation = to_openai_format(state["messages"][state["offset"]:])
    if not isinstance(model, str):
        return None, 0, 1, 0, None
    model = replace_model_name(model, row["tstamp"])

    try:
        lang_code = detect_language(state["messages"][state["offset"]][1])
    except IndexError:
        return None, 0, 1, 0, None

    if not all(isinstance(turn["content"], str) for turn in conversation):
        return None, 0, 1, 0, None

    flattened = "".join(turn["content"] for turn in conversation).lower()
    if NETWORK_ERROR_MSG in flattened:
        return None, 0, 0, 1, None

    ip = row["ip"]
    # `all_ips` is a Manager dict shared across workers; the lock makes
    # the check-then-insert race-free so each ip maps to a stable user id.
    with LOCK:
        if ip not in all_ips:
            all_ips[ip] = len(all_ips)
        user_id = all_ips[ip]

    record = dict(
        conversation_id=conversation_id,
        model=model,
        conversation=conversation,
        turn=len(conversation) // 2,
        language=lang_code,
        user_id=user_id,
        tstamp=row["tstamp"],
    )
    return record, 0, 0, 0, model
99138
100- if not all (isinstance (x ["content" ], str ) for x in conversation ):
101- ct_invalid += 1
102- continue
103139
104- messages = "" .join ([x ["content" ] for x in conversation ]).lower ()
105- if NETWORK_ERROR_MSG in messages :
106- ct_network_error += 1
107- continue
140+ def clean_chat_data (log_files , action_type ):
141+ with mp .Pool () as pool :
142+ # Use partial to pass action_type to get_action_type_data
143+ func = partial (get_action_type_data , action_type = action_type )
144+ raw_data = pool .map (func , log_files , chunksize = 1 )
108145
109- ip = row ["ip" ]
110- if ip not in all_ips :
111- all_ips [ip ] = len (all_ips )
112- user_id = all_ips [ip ]
146+ # filter out Nones as some files may not contain any data belong to action_type
147+ raw_data = [r for r in raw_data if r is not None ]
148+ all_ips = MANAGER .dict ()
149+
150+ # Use the multiprocessing Pool
151+ with mp .Pool () as pool :
152+ func = partial (process_data , action_type = action_type , all_ips = all_ips )
153+ results = pool .map (func , raw_data , chunksize = 1 )
113154
114- chats .append (
115- dict (
116- conversation_id = conversation_id ,
117- model = model ,
118- conversation = conversation ,
119- turn = len (conversation ) // 2 ,
120- language = lang_code ,
121- user_id = user_id ,
122- tstamp = row ["tstamp" ],
123- )
124- )
155+ # Initialize counters and collections in the parent process
156+ ct_invalid_conv_id = 0
157+ ct_invalid = 0
158+ ct_network_error = 0
159+ all_models = set ()
160+ chats = []
125161
126- all_models .update ([model ])
162+ # Aggregate results from child processes
163+ for res in results :
164+ if res is None :
165+ continue
166+ data , inv_conv_id , inv , net_err , model = res
167+ ct_invalid_conv_id += inv_conv_id
168+ ct_invalid += inv
169+ ct_network_error += net_err
170+ if data :
171+ chats .append (data )
172+ if model :
173+ all_models .add (model )
127174
128175 chats .sort (key = lambda x : x ["tstamp" ])
129176 last_updated_tstamp = chats [- 1 ]["tstamp" ]
130- last_updated_datetime = datetime .datetime . fromtimestamp (
177+ last_updated_datetime = datetime .fromtimestamp (
131178 last_updated_tstamp , tz = timezone ("US/Pacific" )
132179 ).strftime ("%Y-%m-%d %H:%M:%S %Z" )
133180
@@ -161,7 +208,7 @@ def clean_chat_data(log_files, action_type):
161208 log_files = get_log_files (args .max_num_files )
162209 chats = clean_chat_data (log_files , args .action_type )
163210 last_updated_tstamp = chats [- 1 ]["tstamp" ]
164- cutoff_date = datetime .datetime . fromtimestamp (
211+ cutoff_date = datetime .fromtimestamp (
165212 last_updated_tstamp , tz = timezone ("US/Pacific" )
166213 ).strftime ("%Y%m%d" )
167214
0 commit comments