Skip to content

Commit 23ad620

Browse files
committed
fix esmm
1 parent 2743bde commit 23ad620

File tree

6 files changed

+13
-8
lines changed

6 files changed

+13
-8
lines changed

models/multitask/esmm/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ runner:
1818
train_reader_path: "esmm_reader" # importlib format
1919
use_gpu: False
2020
use_auc: True
21+
auc_num: 2
2122
train_batch_size: 2
2223
epochs: 3
2324
print_interval: 2

models/multitask/esmm/config_bigdata.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ runner:
1818
train_reader_path: "esmm_reader" # importlib format
1919
use_gpu: True
2020
use_auc: True
21+
auc_num: 2
2122
train_batch_size: 1024
2223
epochs: 10
2324
print_interval: 10

models/multitask/esmm/readme.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ ESMM是发表在 SIGIR’2018 的论文[《Entire Space Multi-Task Model: An E
7777

7878
### 效果复现
7979
为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。如果需要复现readme中的效果,请按如下步骤依次操作即可。
80-
在全量数据下模型的指标如下
80+
在全量数据下模型的训练指标如下
8181
| 模型 | auc_ctcvr | batch_size | epoch_num | Time of each epoch |
8282
| :------| :------ | :------ | :------| :------ |
8383
| ESMM | 0.82 | 1024 | 10 | 约3分钟 |

tools/static_infer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def main(args):
5959

6060
use_gpu = config.get("runner.use_gpu", True)
6161
use_auc = config.get("runner.use_auc", False)
62+
auc_num = config.get("runner.auc_num", 1)
6263
test_data_dir = config.get("runner.test_data_dir", None)
6364
print_interval = config.get("runner.print_interval", None)
6465
model_load_path = config.get("runner.infer_load_path", "model_output")
@@ -92,7 +93,7 @@ def main(args):
9293
epoch_begin = time.time()
9394
interval_begin = time.time()
9495
if use_auc:
95-
reset_auc()
96+
reset_auc(auc_num)
9697
for batch_id, batch_data in enumerate(test_dataloader()):
9798
fetch_batch_var = exe.run(
9899
program=paddle.static.default_main_program(),

tools/static_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def main(args):
6262

6363
use_gpu = config.get("runner.use_gpu", True)
6464
use_auc = config.get("runner.use_auc", False)
65+
auc_num = config.get("runner.auc_num", 1)
6566
train_data_dir = config.get("runner.train_data_dir", None)
6667
epochs = config.get("runner.epochs", None)
6768
print_interval = config.get("runner.print_interval", None)
@@ -93,7 +94,7 @@ def main(args):
9394

9495
epoch_begin = time.time()
9596
if use_auc:
96-
reset_auc()
97+
reset_auc(auc_num)
9798
if reader_type == 'DataLoader':
9899
fetch_batch_var = dataloader_train(epoch_id, train_dataloader,
99100
input_data_names, fetch_vars,

tools/utils/utils_single.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,12 @@ def load_yaml(yaml_file, other_part=None):
130130
return running_config
131131

132132

133-
def reset_auc():
134-
auc_var_name = [
135-
"_generated_var_0", "_generated_var_1", "_generated_var_2",
136-
"_generated_var_3"
137-
]
133+
def reset_auc(auc_num=1):
134+
# for static clear auc
135+
auc_var_name = []
136+
for i in range(auc_num * 4):
137+
auc_var_name.append("_generated_var_%d".format(i))
138+
138139
for name in auc_var_name:
139140
param = paddle.fluid.global_scope().var(name)
140141
if param == None:

0 commit comments

Comments
 (0)