Skip to content

Commit 3e028e1

Browse files
committed
test=develop, add tdm
1 parent 50e5368 commit 3e028e1

File tree

11 files changed

+1274
-0
lines changed

11 files changed

+1274
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from paddle.fluid.proto import index_dataset_pb2
16+
import numpy as np
17+
import struct
18+
import argparse
19+
20+
21+
class TreeIndexBuilder:
    """Build a complete n-ary tree index over item ids and serialize it as
    length-prefixed protobuf KV records (see ``_write_kv``).

    Leaf nodes correspond to items; interior nodes get synthetic ids
    (``id_offset + code``).  A node's position is encoded as an integer
    "code": the root is 0 and child codes are derived from the parent's
    code (see ``gen_code`` / ``_ancessors``).
    """

    def __init__(self, branch=2):
        # Arity of the tree (2 = binary tree).
        self.branch = branch

    def build_by_category(self, input_filename, output_filename):
        """Build a tree from a whitespace-separated ``item_id cat_id`` file.

        Items are deduplicated by item_id and sorted by (cat_id, item_id)
        so that items of the same category end up adjacent in the tree,
        then assigned leaf codes and serialized via ``build``.
        """

        class Item:
            # Lightweight record for one (item_id, cat_id) pair plus the
            # tree code assigned later by gen_code.
            def __init__(self, item_id, cat_id):
                self.item_id = item_id
                self.cat_id = cat_id
                self.code = 0

            def __lt__(self, other):
                # Order by category first, then by item id, so sorting
                # groups same-category items together.
                return self.cat_id < other.cat_id or \
                    (self.cat_id == other.cat_id and
                     self.item_id < other.item_id)

        items = []
        item_id_set = set()
        with open(input_filename, 'r') as f:
            for line in f:
                iterobj = line.split()
                item_id = int(iterobj[0])
                cat_id = int(iterobj[1])
                # Keep only the first occurrence of each item id.
                if item_id not in item_id_set:
                    items.append(Item(item_id, cat_id))
                    item_id_set.add(item_id)
        del item_id_set  # free the dedup set before the recursive pass
        items.sort()

        def gen_code(start, end, code):
            # Recursively split items[start:end] into self.branch nearly
            # equal consecutive slices; the i-th slice becomes the subtree
            # with code ``branch * code + branch - i``.  A single-item
            # range is a leaf and records its code.
            if end <= start:
                return
            if end == start + 1:
                items[start].code = code
                return
            num = int((end - start) / self.branch)
            remain = int((end - start) % self.branch)
            for i in range(self.branch):
                _sub_end = start + (i + 1) * num
                # Spread the remainder one extra item at a time over the
                # leading slices so slice sizes differ by at most one.
                if (remain > 0):
                    remain -= 1
                    _sub_end += 1
                _sub_end = min(_sub_end, end)
                gen_code(start, _sub_end, self.branch * code + self.branch - i)
                start = _sub_end

        gen_code(0, len(items), 0)
        ids = np.array([item.item_id for item in items])
        codes = np.array([item.code for item in items])
        self.build(output_filename, ids, codes)

    def tree_init_by_kmeans(self):
        # Placeholder for a k-means based tree initializer — not implemented.
        pass

    def build(self, output_filename, ids, codes, data=None, id_offset=None):
        """Serialize leaf nodes (``ids``/``codes``), all their ancestor
        nodes, and a trailing TreeMeta record into ``output_filename``.

        ``id_offset`` is the id given to interior node with code 0; it
        defaults to max(ids) + 1 so interior ids never collide with items.
        ``data`` is currently unused.
        """
        # process id offset
        if not id_offset:
            max_id = 0
            for id in ids:
                if id > max_id:
                    max_id = id
            id_offset = max_id + 1

        # sort by codes
        argindex = np.argsort(codes)
        codes = codes[argindex]
        ids = ids[argindex]

        # Trick, make all leaf nodes to be in same level: repeatedly push a
        # leaf to its right child (code -> 2*code + 1) until it reaches the
        # depth of the deepest leaf (min_code is the smallest code on that
        # deepest level).
        # NOTE(review): this levelling and ``_ancessors`` use binary
        # (2*code+1 / (code-1)//2) arithmetic regardless of self.branch —
        # confirm whether branch > 2 is actually supported here.
        min_code = 0
        max_code = codes[-1]
        while max_code > 0:
            min_code = min_code * 2 + 1
            max_code = int((max_code - 1) / 2)

        for i in range(len(codes)):
            while codes[i] < min_code:
                codes[i] = codes[i] * 2 + 1

        filter_set = set()  # interior codes already emitted
        max_level = 0
        tree_meta = index_dataset_pb2.TreeMeta()

        with open(output_filename, 'wb') as f:
            for id, code in zip(ids, codes):
                # One leaf Node per (id, code) pair.
                node = index_dataset_pb2.Node()
                node.id = id
                node.is_leaf = True
                node.probability = 1.0

                kv_item = index_dataset_pb2.KVItem()
                kv_item.key = self._make_key(code)
                kv_item.value = node.SerializeToString()
                self._write_kv(f, kv_item.SerializeToString())

                ancessors = self._ancessors(code)
                # Tree height = number of ancestors of a leaf + 1.
                if len(ancessors) + 1 > max_level:
                    max_level = len(ancessors) + 1

                for ancessor in ancessors:
                    # Emit each interior node only once.
                    if ancessor not in filter_set:
                        node = index_dataset_pb2.Node()
                        node.id = id_offset + ancessor  # id = id_offset + code
                        node.is_leaf = False
                        node.probability = 1.0
                        kv_item = index_dataset_pb2.KVItem()
                        kv_item.key = self._make_key(ancessor)
                        kv_item.value = node.SerializeToString()
                        self._write_kv(f, kv_item.SerializeToString())
                        filter_set.add(ancessor)

            # Trailing metadata record under the reserved key '.tree_meta'.
            tree_meta.branch = self.branch
            tree_meta.height = max_level
            kv_item = index_dataset_pb2.KVItem()
            kv_item.key = '.tree_meta'
            kv_item.value = tree_meta.SerializeToString()
            self._write_kv(f, kv_item.SerializeToString())

    def _ancessors(self, code):
        # Return all ancestor codes of ``code`` up to (and including) the
        # root, using the binary-heap parent rule parent = (code - 1) // 2.
        ancs = []
        while code > 0:
            code = int((code - 1) / 2)
            ancs.append(code)
        return ancs

    def _make_key(self, code):
        # Storage key of a node is its code rendered as a string.
        # NOTE(review): if KVItem.key is a protobuf ``bytes`` field this
        # needs .encode() under Python 3 — confirm against the .proto.
        return str(code)

    def _write_kv(self, fwr, message):
        # Length-prefixed record: 4-byte native-endian int length, then the
        # serialized payload.
        fwr.write(struct.pack('i', len(message)))
        fwr.write(message)
if __name__ == '__main__':
    # Command-line entry point: build a tree index from an item/category
    # file, either by category sort ('by_category') or by k-means
    # ('by_kmeans', not implemented yet).
    parser = argparse.ArgumentParser(description="TreeIndexBuilder")
    parser.add_argument(
        "--branch", required=False, type=int, default=2, help="tree branch.")
    parser.add_argument(
        "--mode",
        required=True,
        choices=['by_category', 'by_kmeans'],
        help="mode")
    parser.add_argument("--input", required=True, help="input filename")
    parser.add_argument("--output", required=True, help="output filename")

    args = parser.parse_args()
    builder = TreeIndexBuilder(args.branch)
    if args.mode == "by_category":
        builder.build_by_category(args.input, args.output)
    elif args.mode == "by_kmeans":
        # BUG FIX: the original called builder.tree_init_by_category(
        # args.input, args.output), which does not exist and would raise
        # AttributeError; the k-means initializer is tree_init_by_kmeans()
        # (currently a no-op placeholder).
        builder.tree_init_by_kmeans()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (C) 2016-2018 Alibaba Group Holding Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
# Copyright 2018 Alibaba Inc. All Rights Reserved
17+
18+
import argparse
19+
import random
20+
import time
21+
import numpy as np
22+
23+
24+
class DataCutter(object):
    """Split a per-user behavior log into a train file and a test file.

    The input is a CSV file with 5 comma-separated fields per line, the
    first field being the user id.  All lines of ``number`` randomly
    chosen users go to the test file; every other user's lines go to the
    train file.
    """

    def __init__(self, inp, train, test, number):
        # inp:    input filename (CSV, 5 fields per line, user id first)
        # train:  output filename for the training split
        # test:   output filename for the test split
        # number: how many users to put in the test split
        self._input = inp
        self._train = train
        self._test = test
        self._number = number

    def cut(self):
        """Read the input, group lines by user id, and write both splits."""
        user_behav = dict()
        user_ids = list()
        with open(self._input) as f:
            for line in f:
                arr = line.strip().split(',')
                # NOTE(review): a malformed line aborts the whole read
                # (`break`, not `continue`) — presumably the file may end
                # with a footer; confirm this is intended.
                if len(arr) != 5:
                    break

                if arr[0] not in user_behav:
                    user_ids.append(arr[0])
                    user_behav[arr[0]] = list()

                # Keep the raw line (with its newline) for verbatim output.
                user_behav[arr[0]].append(line)

        # Randomly choose which users form the test set.
        random.shuffle(user_ids)
        test_user_ids = user_ids[:self._number]
        train_user_ids = user_ids[self._number:]

        # BUG FIX: both output files were opened in binary mode ('wb')
        # while the buffered lines are text (the input is read in text
        # mode), which raises TypeError on Python 3.  Open in text mode.
        with open(self._train, 'w') as f:
            for uid in train_user_ids:
                for line in user_behav[uid]:
                    f.write(line)

        with open(self._test, 'w') as f:
            for uid in test_user_ids:
                for line in user_behav[uid]:
                    f.write(line)
60+
61+
62+
if __name__ == '__main__':
    # Command-line entry point: read the raw behavior log and emit the
    # train/test splits.
    arg_parser = argparse.ArgumentParser(
        description="DataCutter, split data into train and test.")
    arg_parser.add_argument("--input", required=True, help="input filename")
    arg_parser.add_argument(
        "--train", required=True, help="filename of output train set")
    arg_parser.add_argument(
        "--test", required=True, help="filename of output test set")
    arg_parser.add_argument(
        "--number",
        required=True,
        type=int,
        help="number of users for test set")
    cli_args = arg_parser.parse_args()

    cutter = DataCutter(cli_args.input, cli_args.train, cli_args.test,
                        cli_args.number)
    cutter.cut()

0 commit comments

Comments
 (0)