Skip to content

Commit 7139b6c

Browse files
authored
Merge pull request #718 from yinhaofeng/dcn
dcn_config
2 parents 6d75c9a + a55f845 commit 7139b6c

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

models/rank/dcn/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ deepAndCross模型的组网本质是一个二分类任务,模型代码参考
106106

107107
| 模型 | auc | batch_size | epoch_num| Time of each epoch |
108108
| :------| :------ | :------ | :------| :------ |
109-
| dcn | 0.777 | 32 | 10 | 约 3 小时 |
109+
| dcn | 0.777 | 512 | 10 | 约 3 小时 |
110110

111111
1. 确认您当前所在目录为PaddleRec/models/rank/dcn
112112
2. 在"criteo data"全量数据目录下,运行数据一键处理脚本,命令如下:
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
runner:
17+
train_data_dir: "../../../datasets/criteo/slot_train_data_full"
18+
train_reader_path: "reader" # importlib format
19+
use_gpu: True
20+
use_auc: True
21+
train_batch_size: 512
22+
epochs: 10
23+
print_interval: 10
24+
#model_init_path: "output_model/0" # init model
25+
model_save_path: "output_model_dcn_all"
26+
test_data_dir: "../../../datasets/criteo/slot_test_data_full"
27+
infer_reader_path: "reader" # importlib format
28+
infer_batch_size: 512
29+
infer_load_path: "output_model_dcn_all"
30+
infer_start_epoch: 0
31+
infer_end_epoch: 10
32+
33+
34+
# hyper parameters of user-defined network
35+
hyper_parameters:
36+
# optimizer config
37+
optimizer:
38+
class: Adam
39+
learning_rate: 0.0001
40+
strategy: async
41+
# user-defined <key, value> pairs
42+
sparse_inputs_slots: 27
43+
sparse_feature_number: 1000001
44+
sparse_feature_dim: 9
45+
dense_input_dim: 13
46+
fc_sizes: [512, 256, 128] #, 32]
47+
distributed_embedding: 0
48+
49+
# sparse_inputs_slots + dense_input_dim
50+
51+
cross_num: 2
52+
l2_reg_cross: 0.00005
53+
dnn_use_bn: False
54+
clip_by_norm: 100.0
55+
is_sparse: False
56+
# cat_feat_num: "{workspace}/data/sample_data/cat_feature_num.txt"

0 commit comments

Comments
 (0)