Skip to content

Commit 95de130

Browse files
committed
test=develop, add infer and get_leaf_embedding
1 parent 024faf9 commit 95de130

File tree

5 files changed

+368
-1
lines changed

5 files changed

+368
-1
lines changed

models/treebased/README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,28 @@ demo数据集预处理一键命令为 `./data_prepare.sh demo` 。若对具体
2929
| ├── tree.pb 预处理后,生成的初始化树文件
3030
```
3131

32-
- Step2: TDM快速运行。config.yaml中配置了模型训练所有的超参,运行方式同PaddleRec其他模型静态图运行方式。当前树模型暂不支持动态图运行模式。
32+
- Step2: 训练。config.yaml中配置了模型训练所有的超参,运行方式同PaddleRec其他模型静态图运行方式。当前树模型暂不支持动态图运行模式。
3333

3434
```shell
3535
python -u ../../../tools/static_trainer.py -m config.yaml
3636
```
37+
38+
- Step3: 预测,命令如下所示。其中第一个参数为训练config.yaml位置,第二个参数为预测模型地址。
39+
40+
```
41+
python infer.py config.yaml ./output_model_tdm_demo/0/
42+
```
43+
44+
- Step4: 提取Item(叶子节点)的Embedding,用于重新建树,开始下一轮训练。命令如下所示,其中第一个参数为训练config.yaml位置,第二个参数模型地址,第三个参数为输出文件名称。
45+
46+
```
47+
python get_leaf_embedding.py config.yaml ./output_model_tdm_demo/0/ epoch_0_item_embedding.txt
48+
```
49+
50+
- Step5: 基于Step4得到的Item的Embedding,重新建树。命令如下所示。
51+
52+
```
53+
cd ../builder && python tree_index_builder.py --mode by_kmeans --input epoch_0_item_embedding.txt --output new_tree.pb
54+
```
55+
56+
- Step6: 修改config.yaml中tree文件的路径为最新tree.pb,返回Step2,开始新一轮的训练。

models/treebased/tdm/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ hyper_parameters:
5454
tree_name: "demo"
5555
tree_path: "../demo_data/tree.pb"
5656
with_hierachy: True
57+
topk: 100

models/treebased/tdm/config_ub.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,4 @@ hyper_parameters:
5353
tree_name: "ub"
5454
tree_path: "../ub_data/tree.pb"
5555
with_hierachy: False
56+
topk: 200
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import paddle.fluid as fluid
17+
import numpy as np
18+
import io
19+
import sys
20+
from paddle.distributed.fleet.dataset import TreeIndex
21+
import os
22+
23+
paddle.enable_static()
24+
25+
26+
def get_emb_numpy(tree_node_num, node_emb_size, init_model_path=""):
    """Materialize the "TDM_Tree_Emb" embedding table as a numpy array.

    A minimal static graph is assembled just so the embedding parameter
    gets registered in the default programs; if ``init_model_path`` is
    non-empty, the persisted model is restored into it first.

    Args:
        tree_node_num: total number of tree nodes (rows of the table).
        node_emb_size: embedding dimension (columns of the table).
        init_model_path: directory of a trained model; "" keeps the
            randomly initialized values.

    Returns:
        np.ndarray of shape [tree_node_num, node_emb_size].
    """
    node_ids = fluid.layers.data(
        name="all_nodes",
        shape=[-1, 1],
        dtype="int64",
        lod_level=1, )

    # The layer output is never fetched; the call exists only to declare
    # the "TDM_Tree_Emb" parameter so the executor creates/loads it.
    _ = fluid.layers.embedding(
        input=node_ids,
        is_sparse=True,
        size=[tree_node_num, node_emb_size],
        param_attr=fluid.ParamAttr(
            name="TDM_Tree_Emb",
            initializer=paddle.fluid.initializer.UniformInitializer()))

    executor = fluid.Executor(fluid.CPUPlace())
    executor.run(fluid.default_startup_program())
    if init_model_path != "":
        fluid.io.load_persistables(executor, init_model_path)

    emb_var = fluid.global_scope().find_var("TDM_Tree_Emb")
    return np.array(emb_var.get_tensor())
49+
50+
51+
if __name__ == '__main__':
    # Make the shared PaddleRec helpers (tools/utils/static_ps) importable.
    utils_path = "{}/tools/utils/static_ps".format(
        os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))
    sys.path.append(utils_path)
    print(utils_path)
    import common

    # argv[1]: training config.yaml, argv[2]: trained model directory,
    # argv[3]: output file for the leaf (item) embeddings.
    yaml_helper = common.YamlHelper()
    config = yaml_helper.load_yaml(sys.argv[1])

    tree_name = config.get("hyper_parameters.tree_name")
    tree_path = config.get("hyper_parameters.tree_path")
    tree_node_num = config.get("hyper_parameters.sparse_feature_num")
    node_emb_size = config.get("hyper_parameters.node_emb_size")

    tensor = get_emb_numpy(tree_node_num, node_emb_size, sys.argv[2])

    tree = TreeIndex(tree_name, tree_path)
    all_leafs = tree.get_all_leafs()

    # One CSV row per leaf: "<item_id>,<dim_0>,...,<dim_n-1>".
    with open(sys.argv[3], 'w') as fout:
        for node in all_leafs:
            node_id = node.id()
            # BUGFIX: map() returns an iterator in Python 3, so the original
            # `[str(node_id)] + map(...)` raised TypeError. Build a list.
            emb_vec = [str(node_id)] + [
                str(v) for v in tensor[node_id].tolist()
            ]
            fout.write(",".join(emb_vec))
            fout.write("\n")

models/treebased/tdm/infer.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from model import dnn_model_define
16+
import paddle
17+
import paddle.fluid as fluid
18+
import os
19+
import time
20+
import numpy as np
21+
import multiprocessing as mp
22+
import sys
23+
from paddle.distributed.fleet.dataset import TreeIndex
24+
25+
paddle.enable_static()
26+
27+
28+
class Reader():
    """Parses TDM evaluation samples and exposes them as a generator."""

    def __init__(self, item_nums):
        # Number of "item_<k>" user-behavior slots expected per sample.
        self.item_nums = item_nums

    def line_process(self, line):
        """Parse one sample line into (groudtruth ids, behavior slots).

        Sample layout ('|'-separated fields; the third field holds
        ';'-separated "name@value" features), e.g.:
        378254_6|378254_6|train_unit_id@1045081:1.0;item_59@4856095;...||1.0|
        """
        feature_field = line.strip().split("|")[2]
        groudtruth = []
        output_list = [0] * self.item_nums
        for feature in feature_field.split(";"):
            parts = feature.split("@")
            if parts[0] == "test_unit_id":
                # Value looks like "id:score,id:score"; keep the ids only.
                groudtruth = [
                    int(g.split(":")[0]) for g in parts[1].split(",")
                ]
            else:
                # "item_<k>" fills slot k-1 of the user-behavior vector.
                slot = int(parts[0].split('_')[1]) - 1
                output_list[slot] = int(parts[1])

        return groudtruth, output_list

    def dataloader(self, file_list):
        """Return a generator factory over every line of every file."""

        def reader():
            for path in file_list:
                with open(path, 'r') as fin:
                    for line in fin:
                        yield self.line_process(line)

        return reader
58+
59+
60+
def net_input(item_nums=69):
    """Declare the static-graph feed holders used by the infer network.

    Returns ``item_nums`` user-behavior holders named "item_1".."item_<n>"
    followed by the candidate tree-node holder "unit_id".
    """
    holders = []
    for idx in range(item_nums):
        holders.append(
            paddle.static.data(
                name="item_" + str(idx + 1), shape=[None, 1], dtype="int64"))

    holders.append(
        paddle.static.data(name="unit_id", shape=[None, 1], dtype="int64"))

    return holders
70+
71+
72+
def mp_run(data, process_num, func, *args):
    """Shard ``data`` over worker processes and aggregate their metrics.

    Each worker runs ``func(res, shard, worker_idx, *args)`` and is expected
    to publish "<idx>_recall", "<idx>_precision" and "<idx>_nums" entries in
    the shared dict ``res`` (see ``infer``). Prints the global recall and
    precision and returns the number of workers launched.
    """
    # BUGFIX: `/` yields a float in Python 3, which broke both the slice
    # below (TypeError: slice indices must be integers) and the `start`
    # accumulator. Use integer floor division, at least one item per shard.
    partn = max(len(data) // process_num, 1)
    start = 0
    p_idx = 0
    ps = []
    manager = mp.Manager()
    res = manager.dict()
    while start < len(data):
        local_data = data[start:start + partn]
        start += partn
        p = mp.Process(target=func, args=(res, local_data, p_idx) + args)
        ps.append(p)
        p.start()
        p_idx += 1
    for p in ps:
        p.join()

    for p in ps:
        p.terminate()

    total_precision_rate = 0.0
    total_recall_rate = 0.0
    total_nums = 0
    for i in range(p_idx):
        print(i)
        total_recall_rate += res["{}_recall".format(i)]
        total_precision_rate += res["{}_precision".format(i)]
        total_nums += res["{}_nums".format(i)]
    print("global recall rate: {} / {} = {}".format(
        total_recall_rate, total_nums, total_recall_rate / float(total_nums)))
    print("global precision rate: {} / {} = {}".format(
        total_precision_rate, total_nums, total_precision_rate / float(
            total_nums)))

    return p_idx
110+
111+
112+
def load_tree_info(name, path, topk=200):
    """Load a TreeIndex and precompute id/code lookup tables.

    Returns ``(id_code_map, code_id_map, branch, first_layer)`` where
    ``first_layer`` holds the node ids of the shallowest tree layer with
    more than ``topk`` nodes — used to seed the beam-search retrieval.
    """
    tree = TreeIndex(name, path)
    all_codes = []
    first_layer_code = None
    for level in range(tree.height()):
        codes = tree.get_layer_codes(level)
        # Remember the first (shallowest) layer wide enough for the beam.
        if first_layer_code is None and len(codes) > topk:
            first_layer_code = codes
        all_codes.extend(codes)

    all_ids = tree.get_nodes(all_codes)
    id_code_map = {}
    code_id_map = {}
    for code, node in zip(all_codes, all_ids):
        nid = node.id()
        id_code_map[nid] = code
        code_id_map[code] = nid
    print(len(all_codes), len(all_ids), len(id_code_map), len(code_id_map))

    first_layer = [node.id() for node in tree.get_nodes(first_layer_code)]

    return id_code_map, code_id_map, tree.branch(), first_layer
135+
136+
137+
def infer(res_dict, filelist, process_idx, init_model_path, id_code_map,
          code_id_map, branch, first_layer_set, config):
    """Beam-search retrieval over the TDM tree for one worker process.

    Builds the scoring network, restores parameters from
    ``init_model_path``, then for each test sample walks the tree layer by
    layer: score the current candidates, keep the top-k, expand their
    children, and stop once k leaves have been collected. Publishes
    "<process_idx>_recall" / "_precision" / "_nums" into ``res_dict``
    (consumed by ``mp_run``).
    """
    print(process_idx, filelist, init_model_path)
    item_nums = config.get("hyper_parameters.item_nums", 69)
    topk = config.get("hyper_parameters.topk", 200)
    node_nums = config.get("hyper_parameters.sparse_feature_num")
    node_emb_size = config.get("hyper_parameters.node_emb_size")
    inputs = net_input(item_nums)

    # Shared embedding table for user-behavior items and candidate nodes;
    # the name must match the trained parameter so load_persistables fills it.
    embedding = paddle.nn.Embedding(
        node_nums,
        node_emb_size,
        sparse=True,
        weight_attr=paddle.framework.ParamAttr(
            name="TDM_Tree_Emb",
            initializer=paddle.nn.initializer.Normal(std=0.001)))

    user_feature = inputs[0:item_nums]
    user_feature_emb = list(map(embedding, user_feature))  # [(bs, emb)]

    unit_id_emb = embedding(inputs[-1])
    dout = dnn_model_define(user_feature_emb, unit_id_emb)

    # Probability that each candidate node is positive for this user.
    softmax_prob = paddle.nn.functional.softmax(dout)
    positive_prob = paddle.slice(softmax_prob, axes=[1], starts=[1], ends=[2])
    prob_re = paddle.reshape(positive_prob, [-1])

    _, topk_i = paddle.topk(prob_re, k=topk)
    topk_node = paddle.index_select(inputs[-1], topk_i)

    # Dump the program for debugging/inspection.
    with open("main_program", 'w') as f:
        f.write(str(paddle.static.default_main_program()))

    exe = paddle.static.Executor(fluid.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    print("begin to load parameters")
    fluid.io.load_persistables(exe, dirname=init_model_path)
    print("end load parameters")
    reader_instance = Reader(item_nums)
    reader = reader_instance.dataloader(filelist)

    total_recall_rate = 0.0
    total_precision_rate = 0.0
    total_nums = 0
    child_info = dict()  # node id -> child node ids, memoized across samples
    for groudtruth, user_input in reader():
        total_nums += 1

        recall_result = []
        candidate = first_layer_set

        while (len(recall_result) < topk):
            feed_dict = {}
            # Broadcast the user's behavior vector against every candidate.
            # BUGFIX: was a hard-coded range(1, 70); honor item_nums so a
            # config with a different slot count still feeds correctly.
            for i in range(1, item_nums + 1):
                feed_dict['item_' + str(i)] = np.ones(
                    shape=[len(candidate), 1],
                    dtype='int64') * user_input[i - 1]
            feed_dict['unit_id'] = np.array(
                candidate, dtype='int64').reshape(-1, 1)

            res = exe.run(program=paddle.static.default_main_program(),
                          feed=feed_dict,
                          fetch_list=[topk_node.name])
            topk_node_res = res[0].reshape([-1]).tolist()

            candidate = []
            for node in topk_node_res:
                if node not in child_info:
                    # Enumerate children via code arithmetic: children of
                    # code c are c*branch + 1 .. c*branch + branch.
                    child_info[node] = []
                    node_code = id_code_map[node]
                    for j in range(1, branch + 1):
                        child_code = node_code * branch + j
                        if child_code in code_id_map:
                            child_info[node].append(code_id_map[child_code])

                if len(child_info[node]) == 0:
                    # Leaf node: final retrieval result.
                    recall_result.append(node)
                else:
                    candidate = candidate + child_info[node]

        recall_result = recall_result[:topk]
        intersec = list(set(recall_result).intersection(set(groudtruth)))
        total_recall_rate += float(len(intersec)) / float(len(groudtruth))
        total_precision_rate += float(len(intersec)) / float(
            len(recall_result))

        if (total_nums % 100 == 0):
            print("global recall rate: {} / {} = {}".format(
                total_recall_rate, total_nums, total_recall_rate / float(
                    total_nums)))
            print("global precision rate: {} / {} = {}".format(
                total_precision_rate, total_nums, total_precision_rate /
                float(total_nums)))
    res_dict["{}_recall".format(process_idx)] = total_recall_rate
    res_dict["{}_precision".format(process_idx)] = total_precision_rate
    res_dict["{}_nums".format(process_idx)] = total_nums
    print("process idx:{}, global recall rate: {} / {} = {}".format(
        process_idx, total_recall_rate, total_nums, total_recall_rate / float(
            total_nums)))
    print("process idx:{}, global precision rate: {} / {} = {}".format(
        process_idx, total_precision_rate, total_nums, total_precision_rate /
        float(total_nums)))
243+
244+
245+
if __name__ == '__main__':
    # Make the shared PaddleRec helpers (tools/utils/static_ps) importable.
    utils_path = "{}/tools/utils/static_ps".format(
        os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))
    sys.path.append(utils_path)
    print(utils_path)
    import common

    # argv[1]: training config.yaml, argv[2]: directory of the trained model.
    yaml_helper = common.YamlHelper()
    config = yaml_helper.load_yaml(sys.argv[1])

    test_files_path = "./test_data"
    filelist = [
        "{}/{}".format(test_files_path, fname)
        for fname in os.listdir(test_files_path)
    ]
    print(filelist)

    init_model_path = sys.argv[2]
    print(init_model_path)

    tree_name = config.get("hyper_parameters.tree_name")
    tree_path = config.get("hyper_parameters.tree_path")
    print("tree_name: {}".format(tree_name))
    print("tree_path: {}".format(tree_path))

    id_code_map, code_id_map, branch, first_layer_set = load_tree_info(
        tree_name, tree_path)

    # Fan the test files out over 12 worker processes.
    mp_run(filelist, 12, infer, init_model_path, id_code_map, code_id_map,
           branch, first_layer_set, config)

0 commit comments

Comments
 (0)