1616import numpy as np
1717import struct
1818import argparse
19+ import os
20+ import time
21+ import collections
22+ import multiprocessing as mp
23+
24+ from sklearn .cluster import KMeans
1925
2026
2127class TreeIndexBuilder :
22- def __init__ (self , branch = 2 ):
23- self .branch = branch
28+ def __init__ (self ):
29+ self .branch = 2
30+ self .timeout = 5
2431
2532 def build_by_category (self , input_filename , output_filename ):
2633 class Item :
@@ -64,20 +71,159 @@ def gen_code(start, end, code):
6471 gen_code (start , _sub_end , self .branch * code + self .branch - i )
6572 start = _sub_end
6673
67- # mid = int((start + end) / 2)
68- # gen_code(mid, end, 2 * code + 1)
69- # gen_code(start, mid, 2 * code + 2)
70-
7174 gen_code (0 , len (items ), 0 )
7275 ids = np .array ([item .item_id for item in items ])
7376 codes = np .array ([item .code for item in items ])
74- # for i in range(len(items)):
75- # print(ids[i], codes[i])
76- #data = np.array([[] for i in range(len(ids))])
7777 self .build (output_filename , ids , codes )
7878
def tree_init_by_kmeans(self, input_filename, output_filename, parall=1):
    """Build the index by recursively bisecting item vectors with k-means.

    Every non-blank line of ``input_filename`` is parsed as
    ``id,v1,v2,...``: an integer item id followed by float components.
    ``parall`` worker processes pull ``(parent_code, row_indices)`` jobs
    from a shared queue, and each one reports back a full-length code array
    over its own pipe. Positive entries of those arrays are merged into
    ``self.codes`` and the result is handed to ``self.build``.

    Args:
        input_filename: Path of the comma-separated vector file.
        output_filename: Path forwarded to ``self.build``.
        parall: Number of worker processes to start.

    Raises:
        ValueError: If ``parall`` is smaller than 1, the file holds no
            records, or a field cannot be parsed as a number.
    """
    if parall < 1:
        # Without a worker nothing ever consumes the root job.
        raise ValueError("parall must be at least 1, got {}".format(parall))

    t1 = time.time()
    ids = []
    data = []
    with open(input_filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                # str.split never returns an empty list, so blank lines
                # have to be filtered before parsing, not after.
                continue
            arr = line.split(',')
            ids.append(int(arr[0]))
            data.append([float(field) for field in arr[1:]])
    if not ids:
        raise ValueError(
            "no records found in {}".format(input_filename))

    # Stored on the instance so the worker processes can read them.
    self.ids = np.array(ids)
    self.data = np.array(data)
    t2 = time.time()
    print("Read data done, {} records read, elapsed: {}".format(
        len(ids), t2 - t1))

    # Seed the work queue with the root node (code 0) covering every row.
    jobs = mp.Queue()
    jobs.put((0, np.arange(len(self.ids))))
    processes = []
    pipes = []
    for _ in range(parall):
        parent_end, child_end = mp.Pipe()
        worker = mp.Process(target=self._train, args=(child_end, jobs))
        processes.append(worker)
        pipes.append(parent_end)
        worker.start()

    # Each worker sends one array with non-zero codes only at the rows it
    # labelled; overlay them into a single result.
    self.codes = np.zeros((len(self.ids),), dtype=np.int64)
    for pipe in pipes:
        codes = pipe.recv()
        assigned = codes > 0
        self.codes[assigned] = codes[assigned]

    for worker in processes:
        worker.join()

    assert jobs.empty()
    self.build(output_filename, self.ids, self.codes, data=self.data)
122+
123+ def _train (self , pipe , queue ):
124+ last_size = - 1
125+ catch_time = 0
126+ processed = False
127+ code = np .zeros ((len (self .ids ), ), dtype = np .int64 )
128+ while True :
129+ for _ in range (5 ):
130+ try :
131+ pcode , index = queue .get (timeout = self .timeout )
132+ except :
133+ index = None
134+ if index is not None :
135+ break
136+
137+ if index is None :
138+ if processed and (last_size <= 1024 or catch_time >= 3 ):
139+ print ("Process {} exits" .format (os .getpid ()))
140+ break
141+ else :
142+ print ("Got empty job, pid: {}, time: {}" .format (os .getpid (
143+ ), catch_time ))
144+ catch_time += 1
145+ continue
146+
147+ processed = True
148+ catch_time = 0
149+ last_size = len (index )
150+ if last_size <= 1024 :
151+ self ._minbatch (pcode , index , code )
152+ else :
153+ tstart = time .time ()
154+ left_index , right_index = self ._cluster (index )
155+ if last_size > 1024 :
156+ print ("Train iteration done, pcode:{}, "
157+ "data size: {}, elapsed time: {}"
158+ .format (pcode , len (index ), time .time () - tstart ))
159+ self .timeout = int (0.4 * self .timeout + 0.6 * (time .time () -
160+ tstart ))
161+ if self .timeout < 5 :
162+ self .timeout = 5
163+
164+ if len (left_index ) > 1 :
165+ queue .put ((2 * pcode + 1 , left_index ))
166+
167+ if len (right_index ) > 1 :
168+ queue .put ((2 * pcode + 2 , right_index ))
169+ process_count = 0
170+ for c in code :
171+ if c > 0 :
172+ process_count += 1
173+ print ("Process {} process {} items" .format (os .getpid (), process_count ))
174+ pipe .send (code )
175+
176+ def _minbatch (self , pcode , index , code ):
177+ dq = collections .deque ()
178+ dq .append ((pcode , index ))
179+ batch_size = len (index )
180+ tstart = time .time ()
181+ while dq :
182+ pcode , index = dq .popleft ()
183+
184+ if len (index ) == 2 :
185+ code [index [0 ]] = 2 * pcode + 1
186+ code [index [1 ]] = 2 * pcode + 2
187+ continue
188+
189+ left_index , right_index = self ._cluster (index )
190+ if len (left_index ) > 1 :
191+ dq .append ((2 * pcode + 1 , left_index ))
192+ elif len (left_index ) == 1 :
193+ code [left_index ] = 2 * pcode + 1
194+
195+ if len (right_index ) > 1 :
196+ dq .append ((2 * pcode + 2 , right_index ))
197+ elif len (right_index ) == 1 :
198+ code [right_index ] = 2 * pcode + 2
199+
200+ print ("Minbatch, batch size: {}, elapsed: {}" .format (
201+ batch_size , time .time () - tstart ))
202+
def _cluster(self, index):
    """Split the rows in ``index`` into two groups of near-equal size.

    The vectors ``self.data[index]`` are fitted with a two-cluster KMeans
    (``random_state=0``). Rows labelled 0 form the left group and rows
    labelled 1 the right group. When the sizes differ by more than one, the
    larger group is handed to ``self._rebalance`` together with each of its
    rows' ``transform`` column for that group's own cluster.

    Args:
        index: Array of row indices into ``self.data``.

    Returns:
        A ``(left_index, right_index)`` pair of index arrays.
    """
    points = self.data[index]
    model = KMeans(n_clusters=2, random_state=0).fit(points)
    assignment = model.labels_
    zero_rows = np.where(assignment == 0)[0]
    one_rows = np.where(assignment == 1)[0]
    left_index, right_index = index[zero_rows], index[one_rows]

    # Positive when the right group is larger, negative when the left is.
    surplus = len(right_index) - len(left_index)
    if surplus > 1:
        own_column = model.transform(points[one_rows])[:, 1]
        return self._rebalance(left_index, right_index, own_column)
    if surplus < -1:
        own_column = model.transform(points[zero_rows])[:, 0]
        return self._rebalance(right_index, left_index, own_column)
    return left_index, right_index
221+
222+ def _rebalance (self , lindex , rindex , distances ):
223+ sorted_index = rindex [np .argsort (distances )]
224+ idx = np .concatenate ((lindex , sorted_index ))
225+ mid = int (len (idx ) / 2 )
226+ return idx [mid :], idx [:mid ]
81227
82228 def build (self , output_filename , ids , codes , data = None , id_offset = None ):
83229 # process id offset
@@ -161,7 +307,11 @@ def _write_kv(self, fwr, message):
161307if __name__ == '__main__' :
162308 parser = argparse .ArgumentParser (description = "TreeIndexBuiler" )
163309 parser .add_argument (
164- "--branch" , required = False , type = int , default = 2 , help = "tree branch." )
310+ "--parallel" ,
311+ required = False ,
312+ type = int ,
313+ default = 12 ,
314+ help = "parallel nums." )
165315 parser .add_argument (
166316 "--mode" ,
167317 required = True ,
@@ -172,8 +322,8 @@ def _write_kv(self, fwr, message):
172322
173323 args = parser .parse_args ()
174324 if args .mode == "by_category" :
175- builder = TreeIndexBuilder (args . branch )
325+ builder = TreeIndexBuilder ()
176326 builder .build_by_category (args .input , args .output )
177327 elif args .mode == "by_kmeans" :
178- builder = TreeIndexBuilder (args . branch )
179- builder .tree_init_by_category (args .input , args .output )
328+ builder = TreeIndexBuilder ()
329+ builder .tree_init_by_kmeans (args .input , args .output , args . parallel )
0 commit comments