
Commit 73c2277

Merge pull request #756 from renmada/aitm
Add aitm model
2 parents 8623ced + fa1dd3d commit 73c2277

File tree

20 files changed: +1285 −96 lines

README_CN.md

Lines changed: 60 additions & 59 deletions
Large diffs are not rendered by default.

README_EN.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -159,7 +159,8 @@ python -u tools/static_trainer.py -m models/rank/dnn/config.yaml # Training wit
 | Rank | [FLEN](models/rank/flen/) | - ||| >=2.1.0 | [2019][FLEN: Leveraging Field for Scalable CTR Prediction](https://arxiv.org/pdf/1911.04690.pdf) |
 | Rank | [DeepRec](models/rank/deeprec/) | - ||| >=2.1.0 | [2017][Training Deep AutoEncoders for Collaborative Filtering](https://arxiv.org/pdf/1708.01715v3.pdf) |
 | Rank | [AutoFIS](models/rank/autofis/) | - ||| >=2.1.0 | [KDD 2020][AutoFIS: Automatic Feature Interaction Selection in Factorization Models for Click-Through Rate Prediction](https://arxiv.org/pdf/2003.11235v3.pdf) |
-| Rank | [DCN_V2](models/rank/dcn_v2/) | - | ✓ | ✓ | >=2.1.0 | [WWW 2021][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/pdf/2008.13535v2.pdf)
+| Rank | [DCN_V2](models/rank/dcn_v2/) | - ||| >=2.1.0 | [WWW 2021][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/pdf/2008.13535v2.pdf)|
+| Rank | [AITM](models/rank/aitm/) | - ||| >=2.1.0 | [KDD 2021][Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising](https://arxiv.org/pdf/2105.08489v2.pdf) |
 | Multi-Task | [PLE](models/multitask/ple/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/ple.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238938) ||| >=2.1.0 | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/abs/10.1145/3383313.3412236) |
 | Multi-Task | [ESMM](models/multitask/esmm/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/esmm.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238583) ||| >=2.1.0 | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://arxiv.org/abs/1804.07931) |
 | Multi-Task | [MMOE](models/multitask/mmoe/)<br>([doc](https://paddlerec.readthedocs.io/en/latest/models/multitask/mmoe.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3238934) ||| >=2.1.0 | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
```
datasets/ali-cpp_aitm/process_public_data.py

Lines changed: 187 additions & 0 deletions
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Process the Ali-CCP (Alibaba Click and Conversion Prediction) dataset.
https://tianchi.aliyun.com/datalab/dataSet.html?dataId=408

@The author:
Dongbo Xi ([email protected])
'''
import numpy as np
import joblib
import re
import random
random.seed(2020)
np.random.seed(2020)
data_path = 'data/sample_skeleton_{}.csv'
common_feat_path = 'data/common_features_{}.csv'
enum_path = 'data/ctrcvr_enum.pkl'
write_path = 'data/ctr_cvr'
use_columns = [
    '101', '121', '122', '124', '125', '126', '127', '128', '129', '205',
    '206', '207', '216', '508', '509', '702', '853', '301'
]


class process(object):
    def __init__(self):
        pass

    def process_train(self):
        # Build a lookup from common-feature id to its key/value dict.
        c = 0
        common_feat_dict = {}
        with open(common_feat_path.format('train'), 'r') as fr:
            for line in fr:
                line_list = line.strip().split(',')
                # The feature string is a run of key\x02value\x03weight
                # triples joined by \x01; only key and value are used.
                kv = np.array(re.split('\x01|\x02|\x03', line_list[2]))
                key = kv[range(0, len(kv), 3)]
                value = kv[range(1, len(kv), 3)]
                feat_dict = dict(zip(key, value))
                common_feat_dict[line_list[0]] = feat_dict
                c += 1
                if c % 100000 == 0:
                    print(c)
        print('join feats...')
        c = 0
        vocabulary = dict(
            zip(use_columns, [{} for _ in range(len(use_columns))]))
        with open(data_path.format('train') + '.tmp', 'w') as fw:
            fw.write('click,purchase,' + ','.join(use_columns) + '\n')
            with open(data_path.format('train'), 'r') as fr:
                for line in fr:
                    line_list = line.strip().split(',')
                    # Drop invalid rows (no click but a purchase).
                    if line_list[1] == '0' and line_list[2] == '1':
                        continue
                    kv = np.array(re.split('\x01|\x02|\x03', line_list[5]))
                    key = kv[range(0, len(kv), 3)]
                    value = kv[range(1, len(kv), 3)]
                    feat_dict = dict(zip(key, value))
                    feat_dict.update(common_feat_dict[line_list[3]])
                    feats = line_list[1:3]
                    for k in use_columns:
                        feats.append(feat_dict.get(k, '0'))
                    fw.write(','.join(feats) + '\n')
                    # Count occurrences of each feature value
                    # (the first occurrence is counted as 0).
                    for k, v in feat_dict.items():
                        if k in use_columns:
                            if v in vocabulary[k]:
                                vocabulary[k][v] += 1
                            else:
                                vocabulary[k][v] = 0
                    c += 1
                    if c % 100000 == 0:
                        print(c)
        print('before filter low freq:')
        for k, v in vocabulary.items():
            print(k + ':' + str(len(v)))
        # Keep only feature values whose count exceeds 10.
        new_vocabulary = dict(
            zip(use_columns, [set() for _ in range(len(use_columns))]))
        for k, v in vocabulary.items():
            for k1, v1 in v.items():
                if v1 > 10:
                    new_vocabulary[k].add(k1)
        vocabulary = new_vocabulary
        print('after filter low freq:')
        for k, v in vocabulary.items():
            print(k + ':' + str(len(v)))
        joblib.dump(vocabulary, enum_path, compress=3)

        print('encode feats...')
        # Map each surviving feature value to an integer id; unseen
        # values fall back to '0'.
        vocabulary = joblib.load(enum_path)
        feat_map = {}
        for feat in use_columns:
            feat_map[feat] = dict(
                zip(vocabulary[feat], range(1, len(vocabulary[feat]) + 1)))
        c = 0
        with open(write_path + '.train', 'w') as fw1:
            with open(write_path + '.dev', 'w') as fw2:
                fw1.write('click,purchase,' + ','.join(use_columns) + '\n')
                fw2.write('click,purchase,' + ','.join(use_columns) + '\n')
                with open(data_path.format('train') + '.tmp', 'r') as fr:
                    fr.readline()  # skip header
                    for line in fr:
                        line_list = line.strip().split(',')
                        new_line = line_list[:2]
                        for value, feat in zip(line_list[2:], use_columns):
                            new_line.append(
                                str(feat_map[feat].get(value, '0')))
                        # Hold out roughly 10% of rows as the dev split.
                        if random.random() >= 0.9:
                            fw2.write(','.join(new_line) + '\n')
                        else:
                            fw1.write(','.join(new_line) + '\n')
                        c += 1
                        if c % 100000 == 0:
                            print(c)

    def process_test(self):
        c = 0
        common_feat_dict = {}
        with open(common_feat_path.format('test'), 'r') as fr:
            for line in fr:
                line_list = line.strip().split(',')
                kv = np.array(re.split('\x01|\x02|\x03', line_list[2]))
                key = kv[range(0, len(kv), 3)]
                value = kv[range(1, len(kv), 3)]
                feat_dict = dict(zip(key, value))
                common_feat_dict[line_list[0]] = feat_dict
                c += 1
                if c % 100000 == 0:
                    print(c)
        print('join feats...')
        c = 0
        with open(data_path.format('test') + '.tmp', 'w') as fw:
            fw.write('click,purchase,' + ','.join(use_columns) + '\n')
            with open(data_path.format('test'), 'r') as fr:
                for line in fr:
                    line_list = line.strip().split(',')
                    if line_list[1] == '0' and line_list[2] == '1':
                        continue
                    kv = np.array(re.split('\x01|\x02|\x03', line_list[5]))
                    key = kv[range(0, len(kv), 3)]
                    value = kv[range(1, len(kv), 3)]
                    feat_dict = dict(zip(key, value))
                    feat_dict.update(common_feat_dict[line_list[3]])
                    feats = line_list[1:3]
                    for k in use_columns:
                        feats.append(str(feat_dict.get(k, '0')))
                    fw.write(','.join(feats) + '\n')
                    c += 1
                    if c % 100000 == 0:
                        print(c)

        print('encode feats...')
        # Reuse the vocabulary built from the training split.
        vocabulary = joblib.load(enum_path)
        feat_map = {}
        for feat in use_columns:
            feat_map[feat] = dict(
                zip(vocabulary[feat], range(1, len(vocabulary[feat]) + 1)))
        c = 0
        with open(write_path + '.test', 'w') as fw:
            fw.write('click,purchase,' + ','.join(use_columns) + '\n')
            with open(data_path.format('test') + '.tmp', 'r') as fr:
                fr.readline()  # skip header
                for line in fr:
                    line_list = line.strip().split(',')
                    new_line = line_list[:2]
                    for value, feat in zip(line_list[2:], use_columns):
                        new_line.append(str(feat_map[feat].get(value, '0')))
                    fw.write(','.join(new_line) + '\n')
                    c += 1
                    if c % 100000 == 0:
                        print(c)


if __name__ == "__main__":
    pros = process()
    pros.process_train()
    pros.process_test()
```
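As a quick illustration of the parsing step above: each feature field is a run of `key\x02value\x03weight` triples joined by `\x01` (the third token is a feature value the script ignores), so the split-and-stride recovers the key/value pairs. A self-contained sketch with made-up ids, not real Ali-CCP data:

```python
# Sketch of the kv parsing in process_train/process_test, run on a
# synthetic string (hypothetical ids, not real Ali-CCP data).
import re
import numpy as np

raw = '101\x0212345\x031.0\x01121\x02678\x031.0'
kv = np.array(re.split('\x01|\x02|\x03', raw))
key = kv[range(0, len(kv), 3)]    # every 3rd token: field ids '101', '121'
value = kv[range(1, len(kv), 3)]  # token after each key: '12345', '678'
print(dict(zip(key, value)))      # {'101': '12345', '121': '678'}
```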

datasets/ali-cpp_aitm/run.sh

Lines changed: 22 additions & 0 deletions
```bash
mkdir data
mkdir data/whole_data && mkdir data/whole_data/train && mkdir data/whole_data/test
train_source_path="./data/sample_train.tar.gz"
train_target_path="train_data"
test_source_path="./data/sample_test.tar.gz"
test_target_path="test_data"
cd data
echo "downloading sample_train.tar.gz......"
curl -# 'http://jupter-oss.oss-cn-hangzhou.aliyuncs.com/file/opensearch/documents/408/sample_train.tar.gz?Expires=1586435769&OSSAccessKeyId=LTAIGx40tjZWxj6q&Signature=ahUDqhvKT1cGjC4%2FIER2EWtq7o4%3D&response-content-disposition=attachment%3B%20' -H 'Proxy-Connection: keep-alive' -H 'Upgrade-Insecure-Requests: 1' -H 'User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.163 Safari/537.36' -H 'Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9' -H 'Accept-Language: zh-CN,zh;q=0.9' --compressed --insecure -o sample_train.tar.gz
cd ..
echo "unzipping sample_train.tar.gz......"
tar -xzvf ${train_source_path} -C data && rm -rf ${train_source_path}
cd data
echo "downloading sample_test.tar.gz......"
curl -# 'http://jupter-oss.oss-cn-hangzhou.aliyuncs.com/file/opensearch/documents/408/sample_test.tar.gz?Expires=1586435821&OSSAccessKeyId=LTAIGx40tjZWxj6q&Signature=OwLMPjt1agByQtRVi8pazsAliNk%3D&response-content-disposition=attachment%3B%20' -H 'Proxy-Connection: keep-alive' -H 'Upgrade-Insecure-Requests: 1' -H 'User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.163 Safari/537.36' -H 'Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9' -H 'Accept-Language: zh-CN,zh;q=0.9' --compressed --insecure -o sample_test.tar.gz
cd ..
echo "unzipping sample_test.tar.gz......"
tar -xzvf ${test_source_path} -C data && rm -rf ${test_source_path}
echo "preprocessing data......"
python process_public_data.py
mv data/ctr_cvr.train data/whole_data/train
mv data/ctr_cvr.test data/whole_data/test
```

doc/source/models/rank/aitm.md

Lines changed: 66 additions & 0 deletions
# AITM click-through rate prediction model

Code: [AITM](https://github.com/PaddlePaddle/PaddleRec/tree/master/models/rank/aitm)
If you find our code useful, please give it a star~

## Contents

- [Model introduction](#model-introduction)
- [Data preparation](#data-preparation)
- [Environment](#environment)
- [Quick start](#quick-start)
- [Reproducing results](#reproducing-results)
- [Advanced usage](#advanced-usage)
- [FAQ](#faq)

## Model introduction

In recommendation scenarios, a user's conversion path usually has several intermediate steps (impression -> click -> conversion), and in some industries the path is long. In the credit-card business, for example, it runs impression -> click -> application -> approval -> activation. Steps near the end of the path (approval/activation) take a long time to convert and samples are hard to acquire, so conversion data is scarce and training suffers from severe class imbalance.

The authors design a multi-task framework that makes full use of the samples at every step along the path to improve the model's conversion-rate estimates for the later steps.
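To make the transfer mechanism concrete, here is a minimal Paddle sketch of the paper's Adaptive Information Transfer (AIT) module: single-head attention over two candidates, the information transferred from the previous task and the current task's tower output. This follows the paper's description rather than the repo's `net.py`; the class name `AITModule` and the shared dimension `dim` are illustrative assumptions.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class AITModule(nn.Layer):
    """Fuses the previous task's transferred info with the current
    task's tower output via attention over the two candidates."""

    def __init__(self, dim):
        super(AITModule, self).__init__()
        self.h_v = nn.Linear(dim, dim)  # value projection
        self.h_k = nn.Linear(dim, dim)  # key projection
        self.h_q = nn.Linear(dim, dim)  # query projection
        self.dim = dim

    def forward(self, p_prev, q_curr):
        # [batch, 2, dim]: candidate 0 = transferred info from the
        # previous task, candidate 1 = current task's tower output.
        x = paddle.stack([p_prev, q_curr], axis=1)
        v, k, q = self.h_v(x), self.h_k(x), self.h_q(x)
        # Scaled dot-product attention weights over the two candidates.
        w = F.softmax(paddle.sum(k * q, axis=-1) / self.dim ** 0.5, axis=1)
        return paddle.sum(w.unsqueeze(-1) * v, axis=1)  # [batch, dim]
```

In the full model each conversion step keeps its own tower, and the paper further adds a calibration term to the loss that penalizes a later step's predicted probability exceeding the earlier step's.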
## Data preparation

The dataset is [Ali-CCP click](https://tianchi.aliyun.com/datalab/dataSet.html?dataId=408).
Sample data for a quick run is provided under the data directory of the model; to use the full dataset, see [Reproducing results](#reproducing-results) below.

## Environment

PaddlePaddle>=2.0

python 2.7/3.5/3.6/3.7

os: windows/linux/macos

## Quick start

Sample data is provided for a quick try-out, and the commands can be run from any directory. The quick-start commands in the aitm model directory are:

```bash
# enter the model directory
# cd models/rank/aitm # works from any directory
# dygraph training
python -u ../../../tools/trainer.py -m config.yaml

# dygraph inference
python -u ../../../tools/infer.py -m config.yaml
```

## Reproducing results

To help users run every model quickly, sample data is provided under each model. To reproduce the results in this readme, follow the steps below in order.

On the full dataset the model's metrics are:

| Model | click auc | purchase auc | batch_size | epoch_num | Time per epoch |
| :------| :------ | :------ | :------ | :------| :------ |
| aitm | 0.6186 | 0.6525 | 2000 | 6 | ~3 hours |

1. Make sure your current directory is PaddleRec/models/rank/aitm
2. Go to PaddleRec/datasets/ali-cpp_aitm
3. Run the script to process the full dataset

```bash
cd ../../../datasets/ali-cpp_aitm
sh run.sh
```

```bash
cd - # back to the model directory
# dygraph training
python -u ../../../tools/trainer.py -m config_bigdata.yaml
python -u ../../../tools/infer.py -m config_bigdata.yaml
```

## Advanced usage

## FAQ

doc/source/readme.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -48,3 +48,4 @@
 [fat_deepffm](https://paddlerec.readthedocs.io/en/latest/models/rank/fat_deepffm.html)
 [deeprec](https://paddlerec.readthedocs.io/en/latest/models/rank/deeprec.html)
 [autofis](https://paddlerec.readthedocs.io/en/latest/models/rank/autofis.html)
+[aitm](https://paddlerec.readthedocs.io/en/latest/models/rank/aitm.html)
```

models/rank/aitm/__init__.py

Lines changed: 13 additions & 0 deletions
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

models/rank/aitm/aitm_reader.py

Lines changed: 50 additions & 0 deletions
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np
import paddle
from paddle.io import Dataset


class RecDataset(Dataset):
    """Loads the preprocessed Ali-CCP CSV: click,purchase,<feature ids>."""

    def __init__(self, file_list, config):
        super(RecDataset, self).__init__()
        self.feature_names = []
        self.datafile = file_list[0]
        self.data = []
        self._load_data()

    def _load_data(self):
        print("start load data from: {}".format(self.datafile))
        count = 0
        with open(self.datafile) as f:
            # The first two columns are the labels; the rest are features.
            self.feature_names = f.readline().strip().split(',')[2:]
            for line in f:
                count += 1
                line = line.strip().split(',')
                line = [int(v) for v in line]
                self.data.append(line)
        print("load data from {} finished".format(self.datafile))

    def __len__(self, ):
        return len(self.data)

    def __getitem__(self, idx):
        line = self.data[idx]
        click = line[0]
        conversion = line[1]
        features = line[2:]
        return click, conversion, features
```
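For context, a minimal usage sketch of this dataset class; the file path and batch size below are placeholders (in PaddleRec the trainer builds the reader from `config.yaml`), so this is illustrative only:

```python
# Hypothetical usage of RecDataset with Paddle's DataLoader; the path
# and batch size are placeholders, not values from this commit.
from paddle.io import DataLoader

dataset = RecDataset(["data/whole_data/train/ctr_cvr.train"], config=None)
loader = DataLoader(dataset, batch_size=2000, shuffle=True)
for click, purchase, features in loader:
    # click/purchase: [batch] label tensors; features: one [batch]
    # tensor per feature column under the default collate function.
    break
```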
