Skip to content

Commit e9adfc4

Browse files
committed
Add distributed unit tests about text_classification/simnet-bow/ctr (#12812)
* add dist ut for text_classification * add dist ut for text_classification * add simnet bow unittest * add dist ut for simnet bow * add trainning data url for simnet bow * add trainning data url for simnet bow * modify simnet test_reader to train reader * add test_dist_ctr * test_dist_ctr can run now * dense update is good * add unit test for selected rows * debug unit test * fix dist sparse update problem * Constant args at init * optimize code * simnet optimize * fix DebugStringEx * optimize sum_op.h * add ScaleOpVarTypeInference * clean code * fix test_dist_transpiler.py * code optimize * modify delta * fix sparse update bug * dist test use one cpu * update some data * remove unused code * add use cuda config * unit test fix * unit test fix * unit test fix * unit test fix * dist_word2vec use CPU * unit test fix * unit test fix * code clean * code clean * merge develop * api spec update * Revert: api spec update * replace simnet data with fake * replace simnet data with fake * update dim * add batch auc * code clean * code clean * modify print to stderr * update simnet delta -> 1e-5 * update RUN_STEP * add use_reader_alloc * add use_reader_alloc * add use_reader_alloc * modify delta * add use_reader_alloc * fix stderr write * python3 compatibility test=develop * python3 compatibility, test=develop * Update dist_text_classification.py * test=develop
1 parent 3f7a7ea commit e9adfc4

19 files changed

+1102
-117
lines changed

paddle/fluid/framework/selected_rows_test.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ class SelectedRowsTester : public ::testing::Test {
2727
selected_rows_.reset(new SelectedRows(rows, height));
2828

2929
Tensor* value = selected_rows_->mutable_value();
30-
value->mutable_data<float>(
30+
auto* data = value->mutable_data<float>(
3131
make_ddim({static_cast<int64_t>(rows.size()), row_numel}), place_);
32+
for (int64_t i = 0; i < value->numel(); ++i) {
33+
data[i] = static_cast<float>(i);
34+
}
3235
}
3336

3437
protected:
@@ -60,6 +63,10 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
6063
ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
6164
ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims());
6265
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
66+
auto* dst_data = dst_tensor.value().data<float>();
67+
for (int64_t i = 0; i < dst_tensor.value().numel(); ++i) {
68+
ASSERT_EQ(dst_data[i], static_cast<float>(i));
69+
}
6370
}
6471

6572
TEST(SelectedRows, SparseTable) {

paddle/fluid/operators/scale_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ class ScaleOpVarTypeInference : public framework::VarTypeInference {
7777
auto out_var_name = op_desc.Output("Out").front();
7878
auto *out_var = block->FindVarRecursive(out_var_name);
7979

80-
out_var->SetType(in_var.GetType());
81-
out_var->SetDataType(in_var.GetDataType());
80+
if (in_var_name != out_var_name) {
81+
out_var->SetType(in_var.GetType());
82+
out_var->SetDataType(in_var.GetDataType());
83+
}
8284
}
8385
};
8486

paddle/fluid/operators/sum_op.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class SumKernel : public framework::OpKernel<T> {
3232
public:
3333
void Compute(const framework::ExecutionContext &context) const override {
3434
auto in_vars = context.MultiInputVar("X");
35-
int N = in_vars.size();
35+
size_t in_num = in_vars.size();
3636
auto out_var = context.OutputVar("Out");
3737

3838
bool in_place = out_var == in_vars[0];
@@ -53,7 +53,7 @@ class SumKernel : public framework::OpKernel<T> {
5353
auto &place =
5454
*context.template device_context<DeviceContext>().eigen_device();
5555
// If in_place, just skip the first tensor
56-
for (int i = in_place ? 1 : 0; i < N; i++) {
56+
for (size_t i = in_place ? 1 : 0; i < in_num; i++) {
5757
if (in_vars[i]->IsType<framework::LoDTensor>()) {
5858
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
5959
if (in_t.numel() == 0) {
@@ -101,13 +101,13 @@ class SumKernel : public framework::OpKernel<T> {
101101

102102
// Runtime InferShape
103103
size_t first_dim = 0;
104-
for (int i = 0; i < N; i++) {
104+
for (size_t i = 0; i < in_num; i++) {
105105
auto &sel_row = get_selected_row(i);
106106
first_dim += sel_row.rows().size();
107107
}
108108

109109
std::vector<int64_t> in_dim;
110-
for (int i = 0; i < N; i++) {
110+
for (size_t i = 0; i < in_num; i++) {
111111
auto &sel_row = get_selected_row(i);
112112
if (sel_row.rows().size() > 0) {
113113
in_dim = framework::vectorize(sel_row.value().dims());
@@ -116,7 +116,8 @@ class SumKernel : public framework::OpKernel<T> {
116116
}
117117
if (in_dim.empty()) {
118118
VLOG(3) << "WARNING: all the inputs are empty";
119-
in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
119+
in_dim =
120+
framework::vectorize(get_selected_row(in_num - 1).value().dims());
120121
} else {
121122
in_dim[0] = static_cast<int64_t>(first_dim);
122123
}
@@ -133,7 +134,7 @@ class SumKernel : public framework::OpKernel<T> {
133134
math::SelectedRowsAddTo<DeviceContext, T> functor;
134135

135136
int64_t offset = 0;
136-
for (int i = 0; i < N; i++) {
137+
for (size_t i = 0; i < in_num; i++) {
137138
auto &sel_row = get_selected_row(i);
138139
if (sel_row.rows().size() == 0) {
139140
continue;

python/paddle/dataset/common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ def download(url, module_name, md5sum, save_name=None):
7777
retry_limit = 3
7878
while not (os.path.exists(filename) and md5file(filename) == md5sum):
7979
if os.path.exists(filename):
80-
print("file md5", md5file(filename), md5sum)
80+
sys.stderr.write("file %s md5 %s" % (md5file(filename), md5sum))
8181
if retry < retry_limit:
8282
retry += 1
8383
else:
8484
raise RuntimeError("Cannot download {0} within retry limit {1}".
8585
format(url, retry_limit))
86-
print("Cache file %s not found, downloading %s" % (filename, url))
86+
sys.stderr.write("Cache file %s not found, downloading %s" %
87+
(filename, url))
8788
r = requests.get(url, stream=True)
8889
total_length = r.headers.get('content-length')
8990

@@ -100,10 +101,11 @@ def download(url, module_name, md5sum, save_name=None):
100101
dl += len(data)
101102
f.write(data)
102103
done = int(50 * dl / total_length)
103-
sys.stdout.write("\r[%s%s]" % ('=' * done,
104+
sys.stderr.write("\r[%s%s]" % ('=' * done,
104105
' ' * (50 - done)))
105106
sys.stdout.flush()
106-
107+
sys.stderr.write("\n")
108+
sys.stdout.flush()
107109
return filename
108110

109111

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import paddle
18+
import paddle.fluid as fluid
19+
20+
import dist_ctr_reader
21+
from test_dist_base import TestDistRunnerBase, runtime_main
22+
23+
IS_SPARSE = True
24+
25+
# Fix seed for test
26+
fluid.default_startup_program().random_seed = 1
27+
fluid.default_main_program().random_seed = 1
28+
29+
30+
class TestDistCTR2x2(TestDistRunnerBase):
31+
def get_model(self, batch_size=2):
32+
dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta()
33+
""" network definition """
34+
dnn_data = fluid.layers.data(
35+
name="dnn_data",
36+
shape=[-1, 1],
37+
dtype="int64",
38+
lod_level=1,
39+
append_batch_size=False)
40+
lr_data = fluid.layers.data(
41+
name="lr_data",
42+
shape=[-1, 1],
43+
dtype="int64",
44+
lod_level=1,
45+
append_batch_size=False)
46+
label = fluid.layers.data(
47+
name="click",
48+
shape=[-1, 1],
49+
dtype="int64",
50+
lod_level=0,
51+
append_batch_size=False)
52+
53+
# build dnn model
54+
dnn_layer_dims = [128, 64, 32, 1]
55+
dnn_embedding = fluid.layers.embedding(
56+
is_distributed=False,
57+
input=dnn_data,
58+
size=[dnn_input_dim, dnn_layer_dims[0]],
59+
param_attr=fluid.ParamAttr(
60+
name="deep_embedding",
61+
initializer=fluid.initializer.Constant(value=0.01)),
62+
is_sparse=IS_SPARSE)
63+
dnn_pool = fluid.layers.sequence_pool(
64+
input=dnn_embedding, pool_type="sum")
65+
dnn_out = dnn_pool
66+
for i, dim in enumerate(dnn_layer_dims[1:]):
67+
fc = fluid.layers.fc(
68+
input=dnn_out,
69+
size=dim,
70+
act="relu",
71+
param_attr=fluid.ParamAttr(
72+
initializer=fluid.initializer.Constant(value=0.01)),
73+
name='dnn-fc-%d' % i)
74+
dnn_out = fc
75+
76+
# build lr model
77+
lr_embbding = fluid.layers.embedding(
78+
is_distributed=False,
79+
input=lr_data,
80+
size=[lr_input_dim, 1],
81+
param_attr=fluid.ParamAttr(
82+
name="wide_embedding",
83+
initializer=fluid.initializer.Constant(value=0.01)),
84+
is_sparse=IS_SPARSE)
85+
lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
86+
87+
merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
88+
89+
predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
90+
acc = fluid.layers.accuracy(input=predict, label=label)
91+
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
92+
label=label)
93+
cost = fluid.layers.cross_entropy(input=predict, label=label)
94+
avg_cost = fluid.layers.mean(x=cost)
95+
96+
inference_program = paddle.fluid.default_main_program().clone()
97+
98+
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001)
99+
sgd_optimizer.minimize(avg_cost)
100+
101+
dataset = dist_ctr_reader.Dataset()
102+
train_reader = paddle.batch(dataset.train(), batch_size=batch_size)
103+
test_reader = paddle.batch(dataset.test(), batch_size=batch_size)
104+
105+
return inference_program, avg_cost, train_reader, test_reader, None, predict
106+
107+
108+
if __name__ == "__main__":
109+
runtime_main(TestDistCTR2x2)

0 commit comments

Comments
 (0)