
Commit c587dba

update readme && fix codestyle problem

1 parent 0b75aa0 commit c587dba

File tree

3 files changed: +31 −22 lines changed


models/recall/mind/README.md

Lines changed: 2 additions & 2 deletions
@@ -107,15 +107,15 @@ python -u static_infer.py -m config.yaml -top_n 50  # predict on the test data
 Metrics for the model on the full dataset:
 | Model | batch_size | epoch_num | Recall@50 | NDCG@50 | HitRate@50 | Time of each epoch |
 | :------ | :------ | :------ | :------ | :------ | :------ | :------ |
-| mind | 128 | 20 | 8.43% | 13.28% | 17.22% | 398.64s(CPU) |
-
+| mind (Paddle implementation) | 128 | 50 | 5.52% | 4.31% | 11.49% | 356.43s(CPU) |
 
 1. Make sure your current directory is PaddleRec/models/recall/mind
 2. Enter the paddlerec/datasets/AmazonBook directory and run the run.sh script; it downloads the preprocessed AmazonBook dataset and unpacks it into the expected directory
 ```bash
 cd ../../../datasets/AmazonBook
 sh run.sh
 ```
+
 3. Install the dependencies; we use [faiss](https://github.com/facebookresearch/faiss) for vector recall
 ```bash
 # CPU-only version (pip)
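
Step 3 installs faiss for the vector-recall stage of evaluation. For orientation, here is a minimal sketch of the kind of top-N lookup faiss performs in this setting; the index type, shapes, and variable names are illustrative assumptions, not PaddleRec's actual infer code:

```python
import numpy as np
import faiss  # installed via pip (faiss-cpu) per the README

dim = 64                                                   # assumed embedding width
item_emb = np.random.rand(10000, dim).astype("float32")    # hypothetical item vectors
user_vec = np.random.rand(1, dim).astype("float32")        # one user interest vector

index = faiss.IndexFlatIP(dim)                 # exact inner-product index
index.add(item_emb)                            # register all item vectors
scores, item_ids = index.search(user_vec, 50)  # top-50 lookup, matching Recall@50
print(item_ids[0][:5])                         # ids of the best-scoring items
```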

models/recall/mind/config_bigdata.yaml

Lines changed: 3 additions & 3 deletions
@@ -18,15 +18,15 @@ runner:
   use_gpu: False
   use_auc: False
   train_batch_size: 128
-  epochs: 20
+  epochs: 50
   print_interval: 500
   model_save_path: "output_model_mind_all"
   infer_batch_size: 128
   infer_reader_path: "mind_infer_reader" # importlib format
   test_data_dir: "../../../datasets/AmazonBook/valid"
   infer_load_path: "output_model_mind_all"
-  infer_start_epoch: 19
-  infer_end_epoch: 20
+  infer_start_epoch: 49
+  infer_end_epoch: 50
 
   # distribute_config
   # sync_mode: "async"
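
The commit lengthens training from 20 to 50 epochs and shifts the infer window with it, so evaluation still loads only the final checkpoint. A sketch of how such a half-open epoch window is typically consumed; the loop and checkpoint-path layout below are illustrative assumptions, not PaddleRec's exact runner code:

```python
import os

# Values from config_bigdata.yaml after this commit.
infer_load_path = "output_model_mind_all"
infer_start_epoch, infer_end_epoch = 49, 50

# Half-open [start, end): only epoch 49, the last training epoch, is evaluated.
for epoch in range(infer_start_epoch, infer_end_epoch):
    ckpt = os.path.join(infer_load_path, str(epoch))  # assumed layout: <path>/<epoch>
    print(f"would run inference against checkpoint {ckpt}")
```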

models/recall/mind/net.py

Lines changed: 26 additions & 17 deletions
@@ -21,6 +21,7 @@
 class Mind_SampledSoftmaxLoss_Layer(nn.Layer):
     """SampledSoftmaxLoss with LogUniformSampler
     """
+
     def __init__(self,
                  num_classes,
                  n_sample,
@@ -83,13 +84,13 @@ def forward(self, inputs, labels, weights, bias):
         sample_b = all_b[-n_sample:]
 
         # [B, D] * [B, 1,D]
-        true_logist = paddle.sum(paddle.multiply(
-            true_w, inputs.unsqueeze(1)), axis=-1) + true_b
+        true_logist = paddle.sum(paddle.multiply(true_w, inputs.unsqueeze(1)),
+                                 axis=-1) + true_b
         # print(true_logist)
-
+
         sample_logist = paddle.matmul(
-            inputs, sample_w, transpose_y=True) + sample_b
-
+            inputs, sample_w, transpose_y=True) + sample_b
+
         if self.remove_accidental_hits:
             hit = (paddle.equal(labels[:, :], neg_samples))
             padding = paddle.ones_like(sample_logist) * -1e30
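
Beyond the rewrapping, this block computes the true-label logits and the sampled-negative logits, then masks "accidental hits": sampled negatives that collide with a row's true label. A self-contained sketch of that masking idea; the shapes are illustrative, and the final select via paddle.where is an assumption about the code that follows this hunk:

```python
import paddle

sample_logist = paddle.rand([2, 4])            # [batch, n_sample] negative logits
labels = paddle.to_tensor([[1], [3]])          # true class id per row, [batch, 1]
neg_samples = paddle.to_tensor([3, 1, 7, 9])   # shared sampled negative ids

# A negative equal to the row's label is an accidental hit; forcing its logit
# to -1e30 effectively removes it from the softmax.
hit = paddle.equal(labels, neg_samples)        # broadcasts to [2, 4]
padding = paddle.ones_like(sample_logist) * -1e30
sample_logist = paddle.where(hit, padding, sample_logist)
print(sample_logist)
```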
@@ -115,6 +116,7 @@ def forward(self, inputs, labels, weights, bias):
 class Mind_Capsual_Layer(nn.Layer):
     """Mind_Capsual_Layer
     """
+
     def __init__(self,
                  input_units,
                  output_units,
@@ -189,11 +191,13 @@ def forward(self, item_his_emb, seq_len):
 
         low_capsule_new_tile = paddle.tile(low_capsule_new, [1, 1, self.k_max])
         low_capsule_new_tile = paddle.reshape(
-            low_capsule_new_tile, [-1, self.maxlen, self.k_max, self.output_units])
-        low_capsule_new_tile = paddle.transpose(
-            low_capsule_new_tile, [0, 2, 1, 3])
+            low_capsule_new_tile,
+            [-1, self.maxlen, self.k_max, self.output_units])
+        low_capsule_new_tile = paddle.transpose(low_capsule_new_tile,
+                                                [0, 2, 1, 3])
         low_capsule_new_tile = paddle.reshape(
-            low_capsule_new_tile, [-1, self.k_max, self.maxlen, self.output_units])
+            low_capsule_new_tile,
+            [-1, self.k_max, self.maxlen, self.output_units])
         low_capsule_new_nograd = paddle.assign(low_capsule_new_tile)
         low_capsule_new_nograd.stop_gradient = True
 
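This hunk only rewraps a tile → reshape → transpose → reshape chain that broadcasts the user's behavior capsules across the k_max interest capsules before routing. A shape-only sketch of that chain, with dimensions chosen purely for illustration:

```python
import paddle

batch, maxlen, k_max, output_units = 2, 6, 4, 8
low = paddle.randn([batch, maxlen, output_units])   # behavior (low-level) capsules

x = paddle.tile(low, [1, 1, k_max])                        # [2, 6, 32]
x = paddle.reshape(x, [-1, maxlen, k_max, output_units])   # [2, 6, 4, 8]
x = paddle.transpose(x, [0, 2, 1, 3])                      # [2, 4, 6, 8]
x = paddle.reshape(x, [-1, k_max, maxlen, output_units])   # one copy per capsule
print(x.shape)
```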
@@ -209,8 +213,9 @@ def forward(self, item_his_emb, seq_len):
             high_capsule_tmp = paddle.matmul(W, low_capsule_new_nograd)
             # print(low_capsule_new_nograd.shape)
             high_capsule = self.squash(high_capsule_tmp)
-            B_delta = paddle.matmul(low_capsule_new_nograd,
-                                    paddle.transpose(high_capsule, [0, 1, 3, 2]))
+            B_delta = paddle.matmul(
+                low_capsule_new_nograd,
+                paddle.transpose(high_capsule, [0, 1, 3, 2]))
             B_delta = paddle.reshape(
                 B_delta, shape=[-1, self.k_max, self.maxlen])
             B += B_delta
@@ -220,8 +225,8 @@ def forward(self, item_his_emb, seq_len):
         W = paddle.unsqueeze(W, axis=2)
         interest_capsule = paddle.matmul(W, low_capsule_new_tile)
         interest_capsule = self.squash(interest_capsule)
-        high_capsule = paddle.reshape(
-            interest_capsule, [-1, self.k_max, self.output_units])
+        high_capsule = paddle.reshape(interest_capsule,
+                                      [-1, self.k_max, self.output_units])
 
         high_capsule = F.relu(self.relu_layer(high_capsule))
         return high_capsule, W, seq_len
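
Both routing hunks feed their matmul results through self.squash. For reference, a minimal sketch of the standard capsule squash nonlinearity from the MIND/CapsNet papers; this is the textbook formula, not necessarily this file's exact implementation:

```python
import paddle

def squash(x, axis=-1, eps=1e-9):
    # squash(x) = (|x|^2 / (1 + |x|^2)) * (x / |x|): shrinks short vectors
    # toward zero and keeps long vectors just under unit norm.
    sq_norm = paddle.sum(paddle.square(x), axis=axis, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm) / paddle.sqrt(sq_norm + eps)
    return x * scale

caps = paddle.randn([2, 4, 8])   # [batch, k_max, output_units]
print(squash(caps).shape)        # [2, 4, 8]
```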
@@ -277,12 +282,16 @@ def __init__(self,
     def label_aware_attention(self, keys, query):
         """label_aware_attention
         """
-        weight = paddle.matmul(keys, paddle.reshape(query, [-1, paddle.shape(query)[-1], 1])) #[B, K, dim] * [B, dim, 1] == [B, k, 1]
+        weight = paddle.matmul(keys,
+                               paddle.reshape(query, [
+                                   -1, paddle.shape(query)[-1], 1
+                               ]))  #[B, K, dim] * [B, dim, 1] == [B, k, 1]
         weight = paddle.squeeze(weight, axis=-1)
         weight = paddle.pow(weight, self.pow_p)  # [x,k_max]
-        weight = F.softmax(weight) #[x, k_max]
-        weight = paddle.unsqueeze(weight, 1) #[B, 1, k_max]
-        output = paddle.matmul(weight, keys) #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim]
+        weight = F.softmax(weight)  #[x, k_max]
+        weight = paddle.unsqueeze(weight, 1)  #[B, 1, k_max]
+        output = paddle.matmul(
+            weight, keys)  #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim]
         return output.squeeze(1), weight
 
     def forward(self, hist_item, seqlen, labels=None):
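
The reformatted method is MIND's label-aware attention: the target-item embedding scores each of the k_max interest capsules, and a sharpened softmax pools them into one user vector. A self-contained rerun of the same shape flow outside the class; pow_p and the dimensions here are illustrative:

```python
import paddle
import paddle.nn.functional as F

B, K, D = 4, 3, 64                      # batch, k_max capsules, embedding dim
keys = paddle.randn([B, K, D])          # interest capsules
query = paddle.randn([B, D])            # label (target item) embedding
pow_p = 2.0                             # sharpening exponent

weight = paddle.matmul(keys, query.unsqueeze(-1))   # [B, K, D] x [B, D, 1] -> [B, K, 1]
weight = paddle.squeeze(weight, axis=-1)            # [B, K]
weight = F.softmax(paddle.pow(weight, pow_p))       # attention over capsules
output = paddle.matmul(weight.unsqueeze(1), keys)   # [B, 1, K] x [B, K, D] -> [B, 1, D]
print(output.squeeze(1).shape)                      # [4, 64]
```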
