
Commit 5e1ef3a

test=develop, add jtm
1 parent fe0ed2d commit 5e1ef3a

File tree

5 files changed: +473 −33 lines

models/treebased/README.md

Lines changed: 6 additions & 6 deletions
@@ -1,11 +1,11 @@
-# Paddle TDM Solution
+# Paddle TDM-Series Model Solutions
 
-This sample code provides a PaddlePaddle implementation of the [TDM](https://arxiv.org/pdf/1801.02294.pdf) recommendation/retrieval algorithm. TDM is a recommendation solution designed for large-scale recommender systems that can host arbitrary advanced models to retrieve user interests efficiently. Built on a tree structure, it proposes a methodology for hierarchically modeling and retrieving measures of user interest, so that the system can directly use advanced deep learning models to retrieve user interests over the full corpus. The basic idea is to index all items in the corpus with a tree and then train a deep model to support level-by-level retrieval over that tree, reducing the complexity of full-corpus retrieval in large-scale recommendation from O(n) (where n is the number of items) to O(log n).
+This sample code provides PaddlePaddle implementations of tree-based recommendation/retrieval algorithms, including [TDM](https://arxiv.org/pdf/1801.02294.pdf) and [JTM](https://arxiv.org/pdf/1902.07565.pdf). Tree models are a recommendation solution designed for large-scale recommender systems that can host arbitrary advanced models to retrieve user interests efficiently. Built on a tree structure, the approach proposes a methodology for hierarchically modeling and retrieving measures of user interest, so that the system can directly use advanced deep learning models to retrieve user interests over the full corpus. The basic idea is to index all items in the corpus with a tree and then train a deep model to support level-by-level retrieval over that tree, reducing the complexity of full-corpus retrieval in large-scale recommendation from O(n) (where n is the number of items) to O(log n).
 
 
 ## Quick Start
 
-Get started quickly with the TDM model on the demo dataset, in preparation for designing a model suited to your own use case.
+Get started quickly with the TDM-series models on the demo dataset, in preparation for designing a model suited to your own use case.
 
 Assume that ${PaddleRec_Home} is the directory where PaddleRec is located.

@@ -21,15 +21,15 @@ The one-step preprocessing command for the demo dataset is `./data_prepare.sh demo`. For details on the
 ├── treebased
 ├── demo_data
 | ├── samples 	 needed by the JTM Tree-Learning algorithm,
-| | ├── samples_{item_id}.json 	 records all training-set samples related to `item_id`
+| | ├── samples_{item_id}.json 	 records all training-set samples related to item_id
 | ├── train_data 	 training-set directory
 | ├── test_data 	 test-set directory
 | ├── ItemCate.txt 	 records the category information of all items, used to initialize the tree.
 | ├── Stat.txt 	 records how often each item appears in the training set, used for sampling.
-| ├── tree.pb 	 initial tree file
+| ├── tree.pb 	 initial tree file generated by preprocessing
 ```
 
-- Step2: Quick run. config.yaml contains all the hyperparameters for model training; run it the same way as other PaddleRec models in static-graph mode. Tree models do not currently support dynamic-graph mode.
+- Step2: TDM quick run. config.yaml contains all the hyperparameters for model training; run it the same way as other PaddleRec models in static-graph mode. Tree models do not currently support dynamic-graph mode.
 
 ```shell
 python -u ../../../tools/static_trainer.py -m config.yaml
 ```
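
To make the O(log n) claim in the README concrete, here is a minimal sketch (not part of this commit) of layer-wise beam-search retrieval over a complete binary tree; the `score` callable is a hypothetical stand-in for the learned user-preference model:

```python
# Minimal sketch of layer-wise beam-search retrieval over a complete binary
# tree (heap layout: root = 0, children of node i are 2i+1 and 2i+2).
# `score` is a hypothetical stand-in for the trained preference model.
import heapq


def retrieve(score, user, num_leaves, beam_size):
    """Return up to beam_size leaf node ids for the given user."""
    total = 2 * num_leaves - 1  # node count of a complete binary tree
    frontier = [0]              # start the search at the root
    while True:
        children = []
        for node in frontier:
            for c in (2 * node + 1, 2 * node + 2):
                if c < total:
                    children.append(c)
        if not children:        # the frontier has reached the leaf level
            return frontier
        # keep only the beam_size best-scoring nodes per level, so the
        # total number of model calls is O(beam_size * log(num_leaves))
        frontier = heapq.nlargest(
            beam_size, children, key=lambda n: score(user, n))


# toy run with a fake scorer that simply prefers larger node ids
print(retrieve(lambda u, n: n, user=None, num_leaves=8, beam_size=2))
```

With num_leaves items, the loop runs log2(num_leaves) times and scores at most 2 * beam_size nodes per level, which is the complexity argument the README makes.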
Lines changed: 296 additions & 0 deletions
```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import math
import multiprocessing as mp
import os

import numpy as np
import paddle
from paddle.distributed.fleet.dataset import TreeIndex

from user_preference import UserPreferenceModel

paddle.enable_static()


def mp_run(data, process_num, func, *args):
    """Run func over contiguous slices of data with multiple processes."""
    # integer division so the slice bounds are ints
    partn = max(len(data) // process_num, 1)
    start = 0
    p_idx = 0
    ps = []
    while start < len(data):
        local_data = data[start:start + partn]
        start += partn
        p = mp.Process(target=func, args=(local_data, p_idx) + args)
        ps.append(p)
        p.start()
        p_idx += 1
    for p in ps:
        p.join()

    for p in ps:
        p.terminate()
    return p_idx


def get_itemset_given_ancestor(pi_new, node):
    """Return all items currently mapped to the ancestor `node`."""
    res = []
    for ci, code in pi_new.items():
        if code == node:
            res.append(ci)
    return res


# you need to define your own sample set
def get_sample_set(ck, args):
    """Load the training samples related to item ck, subsampled if requested."""
    if not os.path.exists("{}/samples_{}.json".format(args.sample_directory,
                                                      ck)):
        return []
    with open("{}/samples_{}.json".format(args.sample_directory, ck),
              'r') as f:
        all_samples = json.load(f)

    sample_nums = args.sample_nums
    if sample_nums > 0:
        size = len(all_samples)
        if size > sample_nums:
            # draw sample_nums samples without replacement
            sample_set = np.random.choice(
                range(size), size=sample_nums, replace=False).tolist()
            return [all_samples[s] for s in sample_set]
    return all_samples


def get_weights(C_ni, idx, edge_weights, ni, children_of_ni_in_level_l, tree,
                args):
    r"""Use the user preference prediction model to calculate the required weights.

    Returns:
        all weights

    Args:
        C_ni (item, required): item set whose ancestor is the non-leaf node ni
        ni (node, required): a non-leaf node in level l-d
        children_of_ni_in_level_l (list, required): the level-l children of ni
        tree (tree, required): the old tree (\pi_{old})
    """
    tree_emb_size = tree.emb_size()
    prediction_model = UserPreferenceModel(args.init_model_path, tree_emb_size,
                                           args.node_emb_size)

    for ck in C_ni:
        _weights = list()
        # the first element is the list of nodes in level l
        _weights.append([])
        # the second element is the list of corresponding weights
        _weights.append([])

        samples = get_sample_set(ck, args)
        for node in children_of_ni_in_level_l:
            path_to_ni = tree.get_travel_path(node, ni)
            if len(samples) == 0:
                weight = 0.0
            else:
                weight = prediction_model.calc_prediction_weight(samples,
                                                                 path_to_ni)

            _weights[0].append(node)
            _weights[1].append(weight)
        edge_weights.update({ck: _weights})


def assign_parent(tree, l_max, l, d, ni, C_ni, args):
    r"""Implementation of line 5 of Algorithm 2.

    Returns:
        updated \pi_{new}

    Args:
        l_max (int, required): the max level of the tree
        l (int, required): current assign level
        d (int, required): level gap in tree_learning
        ni (node, required): a non-leaf node in level l-d
        C_ni (item, required): item set whose ancestor is the non-leaf node ni
        tree (tree, required): the old tree (\pi_{old})
    """
    # get the children of ni in level l
    children_of_ni_in_level_l = tree.get_children_codes(ni, l)

    # get all the required weights
    edge_weights = mp.Manager().dict()
    mp_run(C_ni, 12, get_weights, edge_weights, ni, children_of_ni_in_level_l,
           tree, args)

    print("finish calculate edge_weights. {}.".format(len(edge_weights)))

    # assign each item to the level-l node with the maximum weight;
    # each entry is (item, candidate cursor, sorted nodes, sorted weights)
    assign_dict = dict()
    for ci, info in edge_weights.items():
        assign_candidate_nodes = np.array(info[0], dtype=np.int64)
        assign_weights = np.array(info[1], dtype=np.float32)
        sorted_idx = np.argsort(-assign_weights)
        sorted_weights = assign_weights[sorted_idx]
        sorted_candidate_nodes = assign_candidate_nodes[sorted_idx]
        # assign item ci to the node with the largest weight
        max_weight_node = sorted_candidate_nodes[0]
        if max_weight_node in assign_dict:
            assign_dict[max_weight_node].append(
                (ci, 0, sorted_candidate_nodes, sorted_weights))
        else:
            assign_dict[max_weight_node] = [
                (ci, 0, sorted_candidate_nodes, sorted_weights)
            ]

    edge_weights = None

    # get each item's original assignment at level l in the old tree,
    # used in the rebalance process
    origin_relation = tree.get_pi_relation(C_ni, l)

    # rebalance: a level-l node can host at most 2^(l_max - l) leaf items
    max_assign_num = int(math.pow(2, l_max - l))
    processed_set = set()

    while True:
        # find the unprocessed node with the most assigned items
        max_assign_cnt = 0
        max_assign_node = None

        for node in children_of_ni_in_level_l:
            if node in processed_set:
                continue
            if node not in assign_dict:
                continue
            if len(assign_dict[node]) > max_assign_cnt:
                max_assign_cnt = len(assign_dict[node])
                max_assign_node = node

        if max_assign_node is None or max_assign_cnt <= max_assign_num:
            break

        # rebalance: items that already lived under this node and have the
        # largest weights are kept; the overflow moves to each item's
        # next-best candidate node
        processed_set.add(max_assign_node)
        elements = assign_dict[max_assign_node]
        elements.sort(
            key=lambda x: (int(max_assign_node != origin_relation[x[0]]), -x[3][x[1]])
        )
        for e in elements[max_assign_num:]:
            idx = e[1] + 1
            while idx < len(e[2]):
                other_parent_node = e[2][idx]
                if other_parent_node in processed_set:
                    idx += 1
                    continue
                if other_parent_node not in assign_dict:
                    assign_dict[other_parent_node] = [(e[0], idx, e[2], e[3])]
                else:
                    assign_dict[other_parent_node].append(
                        (e[0], idx, e[2], e[3]))
                break

        del elements[max_assign_num:]

    pi_new = dict()
    for parent_code, value in assign_dict.items():
        max_assign_num = int(math.pow(2, l_max - l))
        assert len(value) <= max_assign_num
        for e in value:
            assert e[0] not in pi_new
            pi_new[e[0]] = parent_code

    return pi_new


def process(nodes, idx, pi_new_final, tree, l, d, args):
    """Re-assign the items below each node ni in nodes to level l."""
    l_max = tree.height() - 1
    for ni in nodes:
        C_ni = get_itemset_given_ancestor(pi_new_final, ni)
        print("begin to handle {}, have {} items.".format(ni, len(C_ni)))
        if len(C_ni) == 0:
            continue
        pi_star = assign_parent(tree, l_max, l, d, ni, C_ni, args)
        # update pi_new according to the found optimal pi_star
        for item, node in pi_star.items():
            pi_new_final.update({item: node})
        print("end to handle {}.".format(ni))


def tree_learning(args):
    tree = TreeIndex(args.tree_name, args.tree_path)
    d = args.gap

    l_max = tree.height() - 1
    l = d

    # initial assignment: every item starts under its level-(l - d) ancestor
    all_items = [node.id() for node in tree.get_all_leafs()]
    pi_new = tree.get_pi_relation(all_items, l - d)

    pi_new_final = mp.Manager().dict()
    pi_new_final.update(pi_new)

    del all_items
    del pi_new

    while d > 0:
        print("begin to re-assign {} layer by {} layer.".format(l, l - d))
        nodes = tree.get_layer_codes(l - d)
        mp_run(nodes, 12, process, pi_new_final, tree, l, d, args)
        # descend d more levels, shrinking the gap near the bottom of the tree
        d = min(d, l_max - l)
        l = l + d
    print(pi_new_final)


if __name__ == '__main__':
    _PARSER = argparse.ArgumentParser(description="Tree Learning Algorithm.")
    _PARSER.add_argument("--tree_name", required=True, help="tree name.")
    _PARSER.add_argument("--tree_path", required=True, help="tree path.")
    _PARSER.add_argument(
        "--sample_directory", required=True, help="samples directory")
    _PARSER.add_argument(
        "--output_filename", default="./output.pb", help="new tree filename.")
    _PARSER.add_argument("--gap", type=int, default=7, help="gap.")
    _PARSER.add_argument(
        "--node_emb_size", type=int, default=64, help="node embedding size.")
    _PARSER.add_argument(
        "--sample_nums",
        type=int,
        default=-1,
        help="sample nums. default value is -1, means use all related train samples."
    )
    _PARSER.add_argument(
        "--init_model_path", type=str, default="", help="model path.")
    args = _PARSER.parse_args()

    tree_learning(args)
```
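
As a sanity check on the rebalance step in assign_parent, here is a self-contained toy (values invented, not from this commit) showing what the sort key does: items whose old parent was the over-full node sort first, then by descending weight of their currently chosen candidate, so the overflow that gets moved consists of the weakest newcomers:

```python
# Toy illustration of the rebalance sort key used in assign_parent.
# Each element is (item, cursor, candidate_nodes, weights); data is made up.
import numpy as np

max_assign_node = 42                            # the over-full level-l node
origin_relation = {"a": 42, "b": 7, "c": 42}    # old-tree parents at level l

elements = [
    ("b", 0, np.array([42, 7]), np.array([0.9, 0.8])),  # moved in, w = 0.9
    ("a", 0, np.array([42, 7]), np.array([0.5, 0.1])),  # stayed,   w = 0.5
    ("c", 0, np.array([42, 7]), np.array([0.7, 0.2])),  # stayed,   w = 0.7
]
# items whose old parent was max_assign_node sort first (key part 0),
# then by descending weight of the currently chosen candidate (key part 1)
elements.sort(
    key=lambda x: (int(max_assign_node != origin_relation[x[0]]), -x[3][x[1]])
)
print([e[0] for e in elements])  # ['c', 'a', 'b'] -> 'b' overflows first
```

With a capacity of two, 'b' is trimmed off the end and re-queued under its next-best candidate, node 7. The script itself would be launched along the lines of `python <this_script>.py --tree_name demo --tree_path ./demo_data/tree.pb --sample_directory ./demo_data/samples --init_model_path <trained_model>` (paths hypothetical), with `--gap` controlling the level stride d of the tree-learning loop.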
