Skip to content

Commit 82eed5f

Browse files
authored
Merge pull request #365 from seemingwang/master
use auc for naml
2 parents b93a125 + bffac75 commit 82eed5f

File tree

5 files changed

+50
-17
lines changed

5 files changed

+50
-17
lines changed

models/rank/naml/NAMLDataReader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def init(self):
6363
#line [0]id cate_id sub_cate_id [3]title content
6464
for file in self.article_file_list:
6565
with open(file, "r") as rf:
66+
6667
for l in rf:
6768
line_x = [x.strip() for x in l.split('\t')]
6869
id = line_x[0]

models/rank/naml/README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ python3 -u ../../../tools/infer.py -m config.yaml
9191
#### Loss及AUC计算
9292
- 预测的结果为一个softmax向量,表示实际浏览文章和负采样文章同时出现的情况下被用户浏览的概率
9393
- 样本的损失函数值由交叉熵给出
94-
- 我们同时还会计算预测的acc,即top1的准确率
94+
- 我们同时还会计算预测的auc
9595

9696
## 效果复现
9797
为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。
@@ -111,12 +111,11 @@ python3 -u ../../../tools/trainer.py -m config_bigdata.yaml
111111
以下为训练3个epoch的结果
112112
| 模型 | auc | batch_size | epoch_num| Time of each epoch|
113113
| :------| :------ | :------ | :------| :------ |
114-
| naml | 0.43 | 512 | 4 | 约0.5小时 |
114+
| naml | 0.72 | 50 | 3 | 约4小时 |
115115

116116
预测
117117
```
118118
python3 -u ../../../tools/infer.py -m config_bigdata.yaml
119119
```
120120

121-
期待运行结果如下
122-
INFO - epoch: 1 done, acc: 0.427140, epoch time: 126.27 s
121+
期待预测auc为0.65

models/rank/naml/config_bigdata.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ runner:
1717
train_reader_path: "NAMLDataReader" # importlib format
1818
use_gpu: False
1919
train_batch_size: 50
20-
epochs: 10
20+
epochs: 3
2121
print_interval: 10
2222
#model_init_path: "output_model/0" # init model
2323
model_save_path: "output_model_all"
2424
infer_batch_size: 10
2525
infer_reader_path: "NAMLDataReader" # importlib format
2626
test_data_dir: "../../../datasets/MIND/data/test"
2727
infer_load_path: "output_model_all"
28-
infer_start_epoch: 0
29-
infer_end_epoch: 10
28+
infer_start_epoch: 3
29+
infer_end_epoch: 4
3030

3131
# hyper parameters of user-defined network
3232
hyper_parameters:

models/rank/naml/dygraph_model.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,23 @@
1616
import paddle.nn as nn
1717
import paddle.nn.functional as F
1818
import math
19+
import numpy as np
1920

2021
import net
2122

2223

2324
class DygraphModel():
25+
def __init__(self):
# Set the two constants used by the AUC metric and by score rescaling.
26+
# Passed as num_thresholds to paddle.metric.Auc in create_metrics().
self.bucket = 100000
27+
# Clamp bound read by rescale(): raw scores are limited to
# [-absolute_limt, +absolute_limt] before being mapped towards [0, 1].
# NOTE(review): "limt" looks like a typo for "limit"; rescale() reads
# this exact attribute name, so both must be renamed together.
self.absolute_limt = 200.0
28+
29+
def rescale(self, number):
    """Clamp a raw score and map it linearly onto roughly [0, 1].

    The input is first limited to ``[-self.absolute_limt,
    self.absolute_limt]`` and then shifted and scaled so that the lower
    bound maps to 0 and the upper bound maps to a value just below 1
    (the ``1e-8`` term in the denominator is kept from the original
    formula).

    Args:
        number: A scalar score, or a NumPy array of scores. Arrays are
            processed element-wise, so callers can rescale a whole batch
            in one call instead of looping over elements.

    Returns:
        The rescaled value(s), with the same shape as ``number``.
        NaN inputs propagate as NaN.
    """
    # The attribute spelling "absolute_limt" is the name set in __init__.
    limit = self.absolute_limt
    clipped = np.clip(number, -limit, limit)
    return (clipped + limit) / (limit * 2 + 1e-8)
35+
2436
# define model
2537
def create_model(self, config):
2638
article_content_size = config.get(
@@ -63,8 +75,10 @@ def create_optimizer(self, dy_model, config):
6375
# define metrics such as auc/acc
6476
# multi-task need to define multi metric
6577
def create_metrics(self):
# Build the metric objects and their display names.
# Returns (metrics_list, metrics_list_name): a one-element list holding a
# paddle.metric.Auc instance and the matching name list ["auc"].
# The lines marked "-" below are the accuracy-based setup this diff removes.
66-
metrics_list_name = ["acc"]
67-
auc_metric = paddle.metric.Accuracy()
78+
# metrics_list_name = ["acc"]
79+
# auc_metric = paddle.metric.Accuracy()
80+
metrics_list_name = ["auc"]
81+
# self.bucket (set in __init__) is used as the number of AUC thresholds.
auc_metric = paddle.metric.Auc(num_thresholds=self.bucket)
6882
metrics_list = [auc_metric]
6983
return metrics_list, metrics_list_name
7084

@@ -77,18 +91,37 @@ def train_forward(self, dy_model, metrics_list, batch_data, config):
7791

7892
loss = paddle.nn.functional.cross_entropy(
7993
input=raw, label=paddle.cast(labels, "float32"), soft_label=True)
80-
correct = metrics_list[0].compute(raw, labels)
81-
metrics_list[0].update(correct)
94+
95+
scaled = raw.numpy()
96+
scaled_pre = []
97+
[rows, cols] = scaled.shape
98+
for i in range(rows):
99+
for j in range(cols):
100+
scaled_pre.append(1.0 - self.rescale(scaled[i, j]))
101+
scaled_pre.append(self.rescale(scaled[i, j]))
102+
scaled_np_predict = np.array(scaled_pre).reshape([-1, 2])
103+
metrics_list[0].update(scaled_np_predict,
104+
paddle.reshape(labels, [-1, 1]))
105+
82106
loss = paddle.mean(loss)
83107
print_dict = None
84108
return loss, metrics_list, print_dict
85109

86110
def infer_forward(self, dy_model, metrics_list, batch_data, config):
# Run one inference batch and fold its predictions into metrics_list[0].
# Returns (metrics_list, None).
# Lines marked "-" are the softmax/accuracy code this diff removes; lines
# marked "+" are the replacement that feeds rescaled scores to the metric.
87-
label, sparse_tensor, dense_tensor = self.create_feeds(batch_data,
88-
config)
111+
# dense_tensor is unpacked but not used in the visible body.
labels, sparse_tensor, dense_tensor = self.create_feeds(batch_data,
112+
config)
89113
raw = dy_model(sparse_tensor, None)
90-
raw = paddle.nn.functional.softmax(raw)
91-
correct = metrics_list[0].compute(raw, label)
92-
metrics_list[0].update(correct)
114+
#predict_raw = paddle.nn.functional.softmax(raw)
115+
116+
# Convert the model output to a NumPy array; it is unpacked as 2-D below.
scaled = raw.numpy()
117+
scaled_pre = []
118+
[rows, cols] = scaled.shape
119+
# For every score, append the pair (1 - p, p) where p = self.rescale(score),
# so each score becomes one two-column row after the reshape.
for i in range(rows):
120+
for j in range(cols):
121+
scaled_pre.append(1.0 - self.rescale(scaled[i, j]))
122+
scaled_pre.append(self.rescale(scaled[i, j]))
123+
scaled_np_predict = np.array(scaled_pre).reshape([-1, 2])
124+
# Labels are flattened to one column so there is one label per score row.
# NOTE(review): presumably the metric treats column 1 as the positive-class
# probability -- confirm against the paddle.metric.Auc documentation.
metrics_list[0].update(scaled_np_predict,
125+
paddle.reshape(labels, [-1, 1]))
93126

94127
return metrics_list, None

models/rank/naml/net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, article_content_size, article_title_size, browse_size,
8484
self.sub_category_size = sub_category_size
8585
self.cate_dimension = cate_dimension
8686
self.word_dict_size = word_dict_size
87-
self.conv_out_channel_size = 100
87+
self.conv_out_channel_size = 400
8888
self.attention_projection_size = 100
8989
self.load_word_embedding()
9090
self.attention_vec = []

0 commit comments

Comments
 (0)