Commit 024faf9

test=develop, fix
1 parent 5e1ef3a commit 024faf9

File tree

4 files changed: +172 -120 lines changed


models/treebased/builder/tree_index_builder.py

Lines changed: 165 additions & 15 deletions
@@ -16,11 +16,18 @@
 import numpy as np
 import struct
 import argparse
+import os
+import time
+import collections
+import multiprocessing as mp
+
+from sklearn.cluster import KMeans
 
 
 class TreeIndexBuilder:
-    def __init__(self, branch=2):
-        self.branch = branch
+    def __init__(self):
+        self.branch = 2
+        self.timeout = 5
 
     def build_by_category(self, input_filename, output_filename):
         class Item:
@@ -64,20 +71,159 @@ def gen_code(start, end, code):
                 gen_code(start, _sub_end, self.branch * code + self.branch - i)
                 start = _sub_end
 
-        # mid = int((start + end) / 2)
-        # gen_code(mid, end, 2 * code + 1)
-        # gen_code(start, mid, 2 * code + 2)
-
         gen_code(0, len(items), 0)
         ids = np.array([item.item_id for item in items])
         codes = np.array([item.code for item in items])
-        # for i in range(len(items)):
-        #     print(ids[i], codes[i])
-        #data = np.array([[] for i in range(len(ids))])
         self.build(output_filename, ids, codes)
 
-    def tree_init_by_kmeans(self):
-        pass
+    def tree_init_by_kmeans(self, input_filename, output_filename, parall=1):
+        t1 = time.time()
+        ids = list()
+        data = list()
+        with open(input_filename) as f:
+            for line in f:
+                arr = line.split(',')
+                if not arr:
+                    break
+                ids.append(int(arr[0]))
+                vector = list()
+                for i in range(1, len(arr)):
+                    vector.append(float(arr[i]))
+                data.append(vector)
+        self.ids = np.array(ids)
+        self.data = np.array(data)
+        t2 = time.time()
+        print("Read data done, {} records read, elapsed: {}".format(
+            len(ids), t2 - t1))
+
+        queue = mp.Queue()
+        queue.put((0, np.array(range(len(self.ids)))))
+        processes = []
+        pipes = []
+        for _ in range(parall):
+            a, b = mp.Pipe()
+            p = mp.Process(target=self._train, args=(b, queue))
+            processes.append(p)
+            pipes.append(a)
+            p.start()
+
+        self.codes = np.zeros((len(self.ids), ), dtype=np.int64)
+        for pipe in pipes:
+            codes = pipe.recv()
+            for i in range(len(codes)):
+                if codes[i] > 0:
+                    self.codes[i] = codes[i]
+
+        for p in processes:
+            p.join()
+
+        assert (queue.empty())
+        self.build(output_filename, self.ids, self.codes, data=self.data)
+
+    def _train(self, pipe, queue):
+        last_size = -1
+        catch_time = 0
+        processed = False
+        code = np.zeros((len(self.ids), ), dtype=np.int64)
+        while True:
+            for _ in range(5):
+                try:
+                    pcode, index = queue.get(timeout=self.timeout)
+                except:
+                    index = None
+                if index is not None:
+                    break
+
+            if index is None:
+                if processed and (last_size <= 1024 or catch_time >= 3):
+                    print("Process {} exits".format(os.getpid()))
+                    break
+                else:
+                    print("Got empty job, pid: {}, time: {}".format(
+                        os.getpid(), catch_time))
+                    catch_time += 1
+                    continue
+
+            processed = True
+            catch_time = 0
+            last_size = len(index)
+            if last_size <= 1024:
+                self._minbatch(pcode, index, code)
+            else:
+                tstart = time.time()
+                left_index, right_index = self._cluster(index)
+                if last_size > 1024:
+                    print("Train iteration done, pcode:{}, "
+                          "data size: {}, elapsed time: {}"
+                          .format(pcode, len(index), time.time() - tstart))
+                self.timeout = int(0.4 * self.timeout + 0.6 *
+                                   (time.time() - tstart))
+                if self.timeout < 5:
+                    self.timeout = 5
+
+                if len(left_index) > 1:
+                    queue.put((2 * pcode + 1, left_index))
+
+                if len(right_index) > 1:
+                    queue.put((2 * pcode + 2, right_index))
+        process_count = 0
+        for c in code:
+            if c > 0:
+                process_count += 1
+        print("Process {} process {} items".format(os.getpid(), process_count))
+        pipe.send(code)
+
+    def _minbatch(self, pcode, index, code):
+        dq = collections.deque()
+        dq.append((pcode, index))
+        batch_size = len(index)
+        tstart = time.time()
+        while dq:
+            pcode, index = dq.popleft()
+
+            if len(index) == 2:
+                code[index[0]] = 2 * pcode + 1
+                code[index[1]] = 2 * pcode + 2
+                continue
+
+            left_index, right_index = self._cluster(index)
+            if len(left_index) > 1:
+                dq.append((2 * pcode + 1, left_index))
+            elif len(left_index) == 1:
+                code[left_index] = 2 * pcode + 1
+
+            if len(right_index) > 1:
+                dq.append((2 * pcode + 2, right_index))
+            elif len(right_index) == 1:
+                code[right_index] = 2 * pcode + 2
+
+        print("Minbatch, batch size: {}, elapsed: {}".format(
+            batch_size, time.time() - tstart))
+
+    def _cluster(self, index):
+        data = self.data[index]
+        kmeans = KMeans(n_clusters=2, random_state=0).fit(data)
+        labels = kmeans.labels_
+        l_i = np.where(labels == 0)[0]
+        r_i = np.where(labels == 1)[0]
+        left_index = index[l_i]
+        right_index = index[r_i]
+        if len(right_index) - len(left_index) > 1:
+            distances = kmeans.transform(data[r_i])
+            left_index, right_index = self._rebalance(left_index, right_index,
+                                                      distances[:, 1])
+        elif len(left_index) - len(right_index) > 1:
+            distances = kmeans.transform(data[l_i])
+            left_index, right_index = self._rebalance(right_index, left_index,
+                                                      distances[:, 0])
+
+        return left_index, right_index
+
+    def _rebalance(self, lindex, rindex, distances):
+        sorted_index = rindex[np.argsort(distances)]
+        idx = np.concatenate((lindex, sorted_index))
+        mid = int(len(idx) / 2)
+        return idx[mid:], idx[:mid]
 
     def build(self, output_filename, ids, codes, data=None, id_offset=None):
         # process id offset
@@ -161,7 +307,11 @@ def _write_kv(self, fwr, message):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="TreeIndexBuiler")
     parser.add_argument(
-        "--branch", required=False, type=int, default=2, help="tree branch.")
+        "--parallel",
+        required=False,
+        type=int,
+        default=12,
+        help="parallel nums.")
     parser.add_argument(
         "--mode",
         required=True,
@@ -172,8 +322,8 @@ def _write_kv(self, fwr, message):
 
     args = parser.parse_args()
     if args.mode == "by_category":
-        builder = TreeIndexBuilder(args.branch)
+        builder = TreeIndexBuilder()
         builder.build_by_category(args.input, args.output)
     elif args.mode == "by_kmeans":
-        builder = TreeIndexBuilder(args.branch)
-        builder.tree_init_by_category(args.input, args.output)
+        builder = TreeIndexBuilder()
+        builder.tree_init_by_kmeans(args.input, args.output, args.parallel)
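
Two parts of the new k-means path deserve a closer look. First, the balanced split: whenever the 2-means clustering leaves the halves uneven by more than one item, _cluster passes the larger side's centroid distances to _rebalance, which sorts the larger side by distance to its own centroid, appends it to the smaller side, and cuts the combined index array in half. A standalone sketch with toy indices and made-up distances (nothing below is taken from the repo):

import numpy as np

def rebalance(lindex, rindex, distances):
    # Same logic as _rebalance above: sort the larger cluster by distance
    # to its own centroid, append it to the smaller one, cut in half.
    sorted_index = rindex[np.argsort(distances)]
    idx = np.concatenate((lindex, sorted_index))
    mid = int(len(idx) / 2)
    return idx[mid:], idx[:mid]

left = np.array([7])                        # smaller cluster: 1 item
right = np.array([3, 5, 9, 11, 13])         # larger cluster: 5 items
dist = np.array([0.9, 0.1, 0.4, 0.2, 0.3])  # right items vs. own centroid
print(rebalance(left, right, dist))
# -> (array([13,  9,  3]), array([ 7,  5, 11])): a 3/3 split instead of 1/5

Second, the entry point: workers assign node codes by the pcode -> 2 * pcode + 1 / 2 * pcode + 2 rule, so each leaf code encodes a root-to-leaf path in a binary tree. A hypothetical direct invocation (file names are placeholders; the CLI equivalent is --mode by_kmeans --parallel 12):

from tree_index_builder import TreeIndexBuilder

# Input lines are "item_id,emb_0,emb_1,...", one embedding vector per item,
# as parsed by tree_init_by_kmeans; the output path here is made up.
builder = TreeIndexBuilder()
builder.tree_init_by_kmeans("item_embeddings.csv", "tree.pb", parall=12)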

models/treebased/tdm/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ runner:
   train_batch_size: 100 # 30000
   epochs: 5
   print_interval: 10 # 1000
-  model_save_path: "tdm_demo_output"
+  model_save_path: "output_model_tdm_demo"
 
 # hyper parameters of user-defined network
 hyper_parameters:

models/treebased/tdm/config_ub.yaml

Lines changed: 6 additions & 6 deletions
@@ -20,17 +20,17 @@ runner:
 
   model_path: "static_model.py"
   reader_type: "QueueDataset" # DataLoader / QueueDataset / RecDataset
-  pipe_command: "python ub_reader.py"
+  pipe_command: "python reader.py"
   dataset_debug: False
   split_file_list: False
 
-  train_data_dir: "../ub_data/debug_data"
-  train_reader_path: "ub_reader"
+  train_data_dir: "../ub_data/train_data"
+  train_reader_path: "reader"
 
-  train_batch_size: 300
+  train_batch_size: 30000
   epochs: 5
-  print_interval: 10 # 1000
-  model_save_path: "tdm_demo_ub"
+  print_interval: 1000
+  model_save_path: "output_model_tdm_ub"
 
 # hyper parameters of user-defined network
 hyper_parameters:
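
Context for the pipe_command swap above (this describes the generic QueueDataset contract, not code from this commit): the trainer pipes each raw data file through the configured command, which reads the file from stdin and writes one processed sample per line to stdout. A minimal sketch of that stdin-to-stdout shape, glossing over the DataGenerator line protocol the real reader.py must emit:

import sys

# Illustrative pipe_command filter: echoes whitespace-separated fields.
# A real PaddleRec reader parses fields into slot/value lines instead.
for line in sys.stdin:
    fields = line.rstrip('\n').split()
    print(' '.join(fields))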

models/treebased/tdm/ub_reader.py

Lines changed: 0 additions & 98 deletions
This file was deleted.
