Skip to content

Commit 6870cc0

Browse files
committed
auc for naml
1 parent b93a125 commit 6870cc0

File tree

3 files changed

+44
-10
lines changed

3 files changed

+44
-10
lines changed

models/rank/naml/NAMLDataReader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(self):
6363
#line [0]id cate_id sub_cate_id [3]title content
6464
for file in self.article_file_list:
6565
with open(file, "r") as rf:
66+
6667
for l in rf:
6768
line_x = [x.strip() for x in l.split('\t')]
6869
id = line_x[0]

models/rank/naml/dygraph_model.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,23 @@
1616
import paddle.nn as nn
1717
import paddle.nn.functional as F
1818
import math
19+
import numpy as np
1920

2021
import net
2122

2223

2324
class DygraphModel():
25+
def __init__(self):
26+
self.bucket = 1000000
27+
self.absolute_limt = 200.0
28+
29+
def rescale(self, number):
30+
if number > self.absolute_limt:
31+
number = self.absolute_limt
32+
elif number < -self.absolute_limt:
33+
number = -self.absolute_limt
34+
return (number + self.absolute_limt) / (self.absolute_limt * 2 + 1e-8)
35+
2436
# define model
2537
def create_model(self, config):
2638
article_content_size = config.get(
@@ -63,8 +75,10 @@ def create_optimizer(self, dy_model, config):
6375
# define metrics such as auc/acc
6476
# multi-task need to define multi metric
6577
def create_metrics(self):
66-
metrics_list_name = ["acc"]
67-
auc_metric = paddle.metric.Accuracy()
78+
# metrics_list_name = ["acc"]
79+
# auc_metric = paddle.metric.Accuracy()
80+
metrics_list_name = ["auc"]
81+
auc_metric = paddle.metric.Auc(num_thresholds=self.bucket)
6882
metrics_list = [auc_metric]
6983
return metrics_list, metrics_list_name
7084

@@ -77,18 +91,37 @@ def train_forward(self, dy_model, metrics_list, batch_data, config):
7791

7892
loss = paddle.nn.functional.cross_entropy(
7993
input=raw, label=paddle.cast(labels, "float32"), soft_label=True)
80-
correct = metrics_list[0].compute(raw, labels)
81-
metrics_list[0].update(correct)
94+
95+
scaled = raw.numpy()
96+
scaled_pre = []
97+
[rows, cols] = scaled.shape
98+
for i in range(rows):
99+
for j in range(cols):
100+
scaled_pre.append(1.0 - self.rescale(scaled[i, j]))
101+
scaled_pre.append(self.rescale(scaled[i, j]))
102+
scaled_np_predict = np.array(scaled_pre).reshape([-1, 2])
103+
metrics_list[0].update(scaled_np_predict,
104+
paddle.reshape(labels, [-1, 1]))
105+
82106
loss = paddle.mean(loss)
83107
print_dict = None
84108
return loss, metrics_list, print_dict
85109

86110
def infer_forward(self, dy_model, metrics_list, batch_data, config):
87-
label, sparse_tensor, dense_tensor = self.create_feeds(batch_data,
88-
config)
111+
labels, sparse_tensor, dense_tensor = self.create_feeds(batch_data,
112+
config)
89113
raw = dy_model(sparse_tensor, None)
90-
raw = paddle.nn.functional.softmax(raw)
91-
correct = metrics_list[0].compute(raw, label)
92-
metrics_list[0].update(correct)
114+
#predict_raw = paddle.nn.functional.softmax(raw)
115+
116+
scaled = raw.numpy()
117+
scaled_pre = []
118+
[rows, cols] = scaled.shape
119+
for i in range(rows):
120+
for j in range(cols):
121+
scaled_pre.append(1.0 - self.rescale(scaled[i, j]))
122+
scaled_pre.append(self.rescale(scaled[i, j]))
123+
scaled_np_predict = np.array(scaled_pre).reshape([-1, 2])
124+
metrics_list[0].update(scaled_np_predict,
125+
paddle.reshape(labels, [-1, 1]))
93126

94127
return metrics_list, None

models/rank/naml/net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, article_content_size, article_title_size, browse_size,
8484
self.sub_category_size = sub_category_size
8585
self.cate_dimension = cate_dimension
8686
self.word_dict_size = word_dict_size
87-
self.conv_out_channel_size = 100
87+
self.conv_out_channel_size = 400
8888
self.attention_projection_size = 100
8989
self.load_word_embedding()
9090
self.attention_vec = []

0 commit comments

Comments
 (0)