
Commit a05fce6

Merge remote-tracking branch 'ups/develop' into fix/jit/avx
test=develop
2 parents: d24d282 + d0fdcb2

6 files changed: +174 −27 lines


paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
 cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
 cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
 cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
+cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
 if(WITH_GPU)
     nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
     nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
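With this cc_test target in place, the unit test added below builds as a standalone gtest binary. Assuming the usual CMake/CTest workflow in a Paddle build tree, it can then be run with something like "ctest -R sequence_pooling_test", or by invoking the built binary directly; the exact invocation depends on the local build setup.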

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 32 additions & 7 deletions
@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
   }
 };
 
+template <typename T>
+class SumSeqPoolGradFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& out_grad,
+                  framework::LoDTensor* in_grad) {
+    auto lod = in_grad->lod()[0];
+    int64_t out_w = out_grad.numel() / out_grad.dims()[0];
+    int64_t in_w = in_grad->numel() / in_grad->dims()[0];
+    PADDLE_ENFORCE(in_w == out_w);
+    const T* out_g_data = out_grad.data<T>();
+    T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t in_offset = lod[i] * in_w;
+      const T* out_pos = out_g_data + i * out_w;
+      T* in_pos = in_g_data + in_offset;
+      for (int r = 0; r != h; ++r) {
+        blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
+      }
+    }
+  }
+};
+
 template <typename T>
 class SequencePoolFunctor<platform::CPUDeviceContext, T> {
  public:
@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
       math::SetConstant<platform::CPUDeviceContext, T> functor;
       functor(context, in_grad, 0);
     }
+
+    if (pooltype == "SUM") {
+      math::SumSeqPoolGradFunctor<T> sum_pool_grad;
+      sum_pool_grad(context, out_grad, in_grad);
+      return;
+    }
+
     auto lod = in_grad->lod()[0];
     auto& place = *context.eigen_device();
-    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
                                    static_cast<int>(lod[i + 1]));
@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
 
       if (pooltype == "AVERAGE") {
         in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
-      } else if (pooltype == "SUM") {
-        const T* out_g_data = out_g_t.data<T>();
-        T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
-        for (int r = 0; r != h; ++r) {
-          blas.VCOPY(w, out_g_data, in_g_data + r * w);
-        }
       } else if (pooltype == "SQRT") {
         in_g_e.device(place) =
             (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
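In effect, the backward pass for SUM pooling replicates each sequence's pooled-gradient row across all of that sequence's timesteps; the new functor does this with one BLAS VCOPY per row. A minimal standalone sketch of the same copy pattern, using plain std::vector in place of LoDTensor and BLAS (the names here are illustrative only, not Paddle APIs):

#include <cstdint>
#include <vector>

// For each sequence i (rows [lod[i], lod[i+1]) of in_grad), copy row i of
// out_grad into every row of that range.
void SumSeqPoolGradSketch(const std::vector<size_t>& lod, int64_t width,
                          const std::vector<float>& out_grad,
                          std::vector<float>* in_grad) {
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    for (size_t r = lod[i]; r < lod[i + 1]; ++r) {
      for (int64_t c = 0; c < width; ++c) {
        (*in_grad)[r * width + c] = out_grad[i * width + c];
      }
    }
  }
}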
paddle/fluid/operators/math/sequence_pooling_test.cc

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/sequence_pooling.h"
+#include <gtest/gtest.h>
+#include <vector>
+
+template <typename DeviceContext, typename Place, typename T>
+void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
+  paddle::framework::LoDTensor cpu_out_grad;
+  paddle::framework::LoDTensor cpu_in_grad;
+  paddle::framework::LoDTensor out_grad;
+  paddle::framework::LoDTensor in_grad;
+  const size_t second_dim = 128u;
+
+  // construct out_grad's tensor in cpu
+  const size_t out_first_dim = lod[0].size() - 1;
+  auto out_dims = paddle::framework::make_ddim(
+      {static_cast<int64_t>(out_first_dim), static_cast<int64_t>(second_dim)});
+
+  cpu_out_grad.mutable_data<T>(out_dims, paddle::platform::CPUPlace());
+  for (int64_t i = 0; i < cpu_out_grad.numel(); ++i) {
+    cpu_out_grad.data<T>()[i] = static_cast<T>(i);
+  }
+
+  // copy to dst out_grad
+  auto* place = new Place();
+  DeviceContext* context = new DeviceContext(*place);
+  if (paddle::platform::is_cpu_place(*place)) {
+    out_grad = cpu_out_grad;
+  } else {
+    TensorCopySync(cpu_out_grad, *place, &out_grad);
+  }
+
+  // construct in_grad
+  in_grad.set_lod(lod);
+  auto in_dims = paddle::framework::make_ddim(
+      {static_cast<int64_t>(lod[0].back()), static_cast<int64_t>(second_dim)});
+  in_grad.mutable_data<T>(in_dims, context->GetPlace());
+
+  // check tensor contruction result
+  PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size());
+  for (int64_t i = 1; i < out_grad.dims().size(); ++i) {
+    PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]);
+  }
+
+  // call functor
+  paddle::operators::math::SequencePoolGradFunctor<DeviceContext, T>()(
+      *context, "SUM", out_grad, &in_grad);
+
+  if (paddle::platform::is_cpu_place(*place)) {
+    cpu_in_grad = in_grad;
+  } else {
+    TensorCopySync(in_grad, paddle::platform::CPUPlace(), &cpu_in_grad);
+    cpu_in_grad.set_lod(in_grad.lod());
+  }
+
+  EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
+  EXPECT_EQ(in_grad.lod(), lod);
+
+  if (paddle::platform::is_cpu_place(*place)) {
+    for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
+      int64_t begin = in_grad.lod()[0][i];
+      int64_t end = in_grad.lod()[0][i + 1];
+      paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
+      for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
+        for (int64_t m = 0; m != second_dim; ++m) {
+          EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
+                    out_grad.data<T>()[m + i * second_dim]);
+        }
+      }
+    }
+  } else {
+    for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
+      int64_t begin = cpu_in_grad.lod()[0][i];
+      int64_t end = cpu_in_grad.lod()[0][i + 1];
+      paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end);
+      for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
+        for (int64_t m = 0; m != second_dim; ++m) {
+          EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
+                    cpu_out_grad.data<T>()[m + i * second_dim]);
+        }
+      }
+    }
+  }
+
+  delete place;
+  delete context;
+}
+
+TEST(SequencePoolingGrad, CPU_SUM) {
+  paddle::framework::LoD lod1;
+  lod1.push_back(std::vector<size_t>{0, 10});
+  TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
+                         paddle::platform::CPUPlace, float>(lod1);
+
+  paddle::framework::LoD lod2;
+  lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
+  TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
+                         paddle::platform::CPUPlace, float>(lod2);
+}
+
+#ifdef PADDLE_WITH_CUDA
+TEST(SequencePoolingGrad, CUDA_SUM) {
+  paddle::framework::LoD lod1;
+  lod1.push_back(std::vector<size_t>{0, 10});
+  TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
+                         paddle::platform::CUDAPlace, float>(lod1);
+
+  paddle::framework::LoD lod2;
+  lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
+  TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
+                         paddle::platform::CUDAPlace, float>(lod2);
+}
+#endif
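The LoD offsets used in these tests, e.g. {0, 2, 7, 10}, describe three sequences occupying rows [0, 2), [2, 7), and [7, 10) of in_grad, so the SUM-pooling gradient check above expects the i-th row of out_grad to be repeated once per row of sequence i. A small self-contained sketch of how sequence lengths fall out of such an offset vector (plain C++, not Paddle code):

#include <cstdio>
#include <vector>

int main() {
  // Level-0 LoD offsets, as in the second test case above.
  std::vector<size_t> lod = {0, 2, 7, 10};
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    // Sequence i covers rows [lod[i], lod[i+1]) of the unpooled tensor.
    std::printf("sequence %zu: rows [%zu, %zu), length %zu\n", i, lod[i],
                lod[i + 1], lod[i + 1] - lod[i]);
  }
  return 0;
}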

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
         set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
         py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
         set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
-        # TODO: fix this test
-        #py_test_modules(test_dist_transformer MODULES test_dist_transformer)
-        #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
+
+        py_test_modules(test_dist_transformer MODULES test_dist_transformer)
+        set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
     endif(NOT APPLE)
     py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
 endif()

python/paddle/fluid/tests/unittests/dist_transformer.py

Lines changed: 8 additions & 15 deletions
@@ -35,7 +35,7 @@
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 from paddle.fluid import core
-from test_dist_base import TestDistRunnerBase, runtime_main
+from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
 import paddle.compat as cpt
 from paddle.compat import long_type
 
@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
     for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
-            if batch_id >= 5:
+            if batch_id >= RUN_STEP:
                 break
 
             feed_list = []
             total_num_token = 0
 
-            #if TrainTaskConfig.local:
-            #    lr_rate = lr_scheduler.update_learning_rate()
-            #for place_id, data_buffer in enumerate(
-            #    split_data(
-            #        data, num_part=dev_count)):
-
             if TrainTaskConfig.local:
                 lr_rate = lr_scheduler.update_learning_rate()
 
@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                 init = True
 
             # Validate and save the model for inference.
-            if batch_id == 0 or batch_id == 4:
-                if TrainTaskConfig.val_file_pattern is not None:
-                    val_avg_cost, val_ppl = test()
-                    print("[%f]" % val_avg_cost)
-                else:
-                    assert (False)
+            if TrainTaskConfig.val_file_pattern is not None:
+                val_avg_cost, val_ppl = test()
+                print("[%f]" % val_avg_cost)
+            else:
+                assert (False)
 
 
 #import transformer_reader as reader
@@ -1701,7 +1694,7 @@ def run_pserver(self, args):
 
     def run_trainer(self, args):
         TrainTaskConfig.use_gpu = args.use_cuda
-        sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
+        sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
             args.is_dist, not args.sync_mode)
 
         if args.is_dist:

python/paddle/fluid/tests/unittests/test_dist_transformer.py

Lines changed: 4 additions & 2 deletions
@@ -61,7 +61,8 @@ def _setup_config(self):
 
     def test_dist_train(self):
         download_files()
-        self.check_with_place("dist_transformer.py", delta=1e-5)
+        self.check_with_place(
+            "dist_transformer.py", delta=1e-5, check_error_log=False)
 
 
 class TestDistTransformer2x2Async(TestDistBase):
@@ -70,7 +71,8 @@ def _setup_config(self):
 
     def test_dist_train(self):
         download_files()
-        self.check_with_place("dist_transformer.py", delta=1.0)
+        self.check_with_place(
+            "dist_transformer.py", delta=1.0, check_error_log=False)
 
 
 if __name__ == "__main__":
