
Commit 50e5368

Merge pull request #405 from duyiqi17/mind
update reademe and fix some configs
2 parents 11722f2 + f0d6357 commit 50e5368

File tree

4 files changed (+32, -10 lines):

- models/recall/mind/README.md
- models/recall/mind/config.yaml
- models/recall/mind/config_bigdata.yaml
- models/recall/mind/mind_reader.py


models/recall/mind/README.md

Lines changed: 23 additions & 3 deletions
@@ -76,6 +76,12 @@ os : windows/linux/macos
 
 The quick-start commands, run from the mind model directory, are as follows:
 ```
+# Install faiss
+# CPU
+pip install faiss-cpu
+# GPU
+# pip install faiss-gpu
+
 # Enter the model directory
 # cd models/recall/mind   # can be run from any directory
 # Dynamic-graph training
@@ -99,7 +105,7 @@ python -u static_infer.py -m config.yaml -top_n 50 # run prediction on the test data
 The model metrics on the full dataset are as follows:
 | Model | batch_size | epoch_num | Recall@50 | NDCG@50 | HitRate@50 | Time of each epoch |
 | :------ | :------ | :------ | :------ | :------ | :------ | :------ |
-| mind | 128 | 20 | 8.43% | 13.28% | 17.22% | -- |
+| mind | 128 | 20 | 8.43% | 13.28% | 17.22% | 398.64s (CPU) |
 
 
 1. Make sure your current working directory is PaddleRec/models/recall/mind
@@ -108,12 +114,26 @@ python -u static_infer.py -m config.yaml -top_n 50 # run prediction on the test data
 cd ../../../datasets/AmazonBook
 sh run.sh
 ```
-3. Switch back to the model directory and run the commands on the full dataset
+3. Install dependencies; we use [faiss](https://github.com/facebookresearch/faiss) for vector recall
+```bash
+# CPU-only version (pip)
+pip install faiss-cpu
+
+# GPU(+CPU) version (pip)
+# pip install faiss-gpu
+
+# CPU-only version (conda)
+# conda install -c pytorch faiss-cpu
+
+# GPU(+CPU) version (conda)
+# conda install -c pytorch faiss-gpu
+```
+4. Switch back to the model directory and run the commands on the full dataset
 ```bash
 cd - # switch back to the model directory
 # Dynamic-graph training
 python -u ../../../tools/trainer.py -m config_bigdata.yaml # run config_bigdata on the full dataset
-python -u infer.py -m config_bigdata.yaml # run config_bigdata on the full dataset
+python -u infer.py -m config_bigdata.yaml -top_n 50 # run config_bigdata on the full dataset
 ```
 
 ## Advanced usage
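For readers new to the faiss dependency added above, the following is a minimal, hypothetical sketch of the top-N vector-recall step it enables. It is not the repository's infer.py; the array names and sizes are illustrative, with the dimensions borrowed from config.yaml and the `-top_n 50` flag used above.

```python
# Hypothetical illustration of faiss-based top-N recall; not PaddleRec code.
import faiss
import numpy as np

embedding_dim = 64   # matches embedding_dim / hidden_size in config.yaml
top_n = 50           # matches the -top_n 50 flag passed to infer.py

# Toy item and user-interest embeddings; faiss expects float32 arrays.
item_vecs = np.random.rand(367983, embedding_dim).astype("float32")
user_interest_vecs = np.random.rand(128, embedding_dim).astype("float32")

# Build a flat (exact) inner-product index over all item vectors.
index = faiss.IndexFlatIP(embedding_dim)
index.add(item_vecs)

# Retrieve the 50 highest-scoring items for each user interest vector.
scores, item_ids = index.search(user_interest_vecs, top_n)
print(item_ids.shape)  # (128, 50)
```

The same calls work against the faiss-gpu package as well; moving the index onto a GPU would additionally require faiss's GPU resource wrappers, which the CPU-only build does not ship.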

models/recall/mind/config.yaml

Lines changed: 5 additions & 4 deletions
@@ -15,18 +15,19 @@
 runner:
   train_data_dir: "data/train"
   train_reader_path: "mind_reader" # importlib format
-  use_gpu: True
+  use_gpu: False
   use_auc: False
   train_batch_size: 128
-  epochs: 2
-  print_interval: 500
+  epochs: 1
+  print_interval: 10
   model_save_path: "output_model_mind"
   infer_batch_size: 128
   infer_reader_path: "mind_infer_reader" # importlib format
   test_data_dir: "data/valid"
   infer_load_path: "output_model_mind"
   infer_start_epoch: 0
   infer_end_epoch: 1
+  batches_per_epoch: 100
 
   # distribute_config
   # sync_mode: "async"
@@ -45,6 +46,6 @@ hyper_parameters:
   item_count: 367983
   embedding_dim: 64
   hidden_size: 64
-  neg_samples: 1280
+  neg_samples: 128
   maxlen: 20
   pow_p: 1.0
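Taken together with the mind_reader.py change further down, the new `batches_per_epoch: 100` key mainly shortens how long a demo epoch runs; a rough, self-contained check of the arithmetic, using only values visible in this diff:

```python
# Back-of-the-envelope check of the demo epoch length; not a PaddleRec API.
train_batch_size = 128        # runner.train_batch_size above
batches_per_epoch = 100       # new runner key added above
old_hardcoded_batches = 1000  # previous constant in mind_reader.py (see below)

print(batches_per_epoch * train_batch_size)      # 12800 sampled users per demo epoch
print(old_hardcoded_batches * train_batch_size)  # 128000 under the old hard-coded limit
```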

models/recall/mind/config_bigdata.yaml

Lines changed: 2 additions & 2 deletions
@@ -15,10 +15,10 @@
 runner:
   train_data_dir: "../../../datasets/AmazonBook/train"
   train_reader_path: "mind_reader" # importlib format
-  use_gpu: True
+  use_gpu: False
   use_auc: False
   train_batch_size: 128
-  epochs: 6
+  epochs: 20
   print_interval: 500
   model_save_path: "output_model_mind"
   infer_batch_size: 128

models/recall/mind/mind_reader.py

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ def __init__(self, file_list, config):
         self.file_list = file_list
         self.maxlen = config.get("hyper_parameters.maxlen", 30)
         self.batch_size = config.get("runner.train_batch_size", 128)
+        self.batches_per_epoch = config.get("runner.batches_per_epoch", 1000)
         self.init()
         self.count = 0

@@ -52,7 +53,7 @@ def init(self):
     def __iter__(self):
         while True:
             user_id_list = random.sample(self.users, self.batch_size)
-            if self.count >= 1000 * self.batch_size:
+            if self.count >= self.batches_per_epoch * self.batch_size:
                 self.count = 0
                 break
             for user_id in user_id_list:
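As a stand-alone illustration of the pattern this change introduces (the sampler now stops after a configurable number of batches instead of a hard-coded 1000), here is a stripped-down, hypothetical sketch; the attribute names mirror mind_reader.py, but this is not the real reader class.

```python
# Hypothetical, self-contained sketch of the epoch-bounding sampler pattern.
import random


class BoundedSampler:
    def __init__(self, users, batch_size=128, batches_per_epoch=100):
        self.users = users
        self.batch_size = batch_size
        self.batches_per_epoch = batches_per_epoch
        self.count = 0

    def __iter__(self):
        while True:
            user_id_list = random.sample(self.users, self.batch_size)
            # End the otherwise infinite sampling loop after one "epoch"
            # worth of samples, then reset the counter for the next epoch.
            if self.count >= self.batches_per_epoch * self.batch_size:
                self.count = 0
                break
            for user_id in user_id_list:
                self.count += 1
                yield user_id  # the real reader yields feature tuples here


sampler = BoundedSampler(users=list(range(1000)), batch_size=8, batches_per_epoch=3)
print(len(list(sampler)))  # 24 == 3 batches * 8 samples per batch
```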
