Commit bddd4bc

Merge remote-tracking branch 'origin/develop' into memory/stable
2 parents: da8adf1 + 437debf

8 files changed, +153 -41 lines changed

cmake/tensorrt.cmake

Lines changed: 2 additions & 0 deletions
@@ -16,7 +16,9 @@ find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
   DOC "Path to TensorRT library.")
 
 if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
+  if(WITH_DSO)
   set(TENSORRT_FOUND ON)
+  endif(WITH_DSO)
 else()
   set(TENSORRT_FOUND OFF)
 endif()

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 1 addition & 1 deletion
@@ -429,7 +429,7 @@ struct LSTM : public PatternBase {
 
 struct GRU : public PatternBase {
   GRU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "lstm") {}
+      : PatternBase(pattern, name_scope, "gru") {}
 
   PDNode* operator()(PDNode* x);

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
+  framework::AsyncIO([var_name_val, s, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -166,7 +166,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
   s->Prepare(h, time_out);
 
   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
-                      time_out, s, this] {
+                      s, this] {
     auto* var = p_scope->FindVar(in_var_name_val);
 
     ::grpc::ByteBuffer req;

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 62 additions & 4 deletions
@@ -103,6 +103,58 @@ class MaxSeqPoolGradFunctor {
   }
 };
 
+template <typename T>
+class LastSeqPoolFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& input,
+                  framework::Tensor* output) {
+    // Create pointers to input and output data
+    auto* in_data = input.data<T>();
+    auto* out_data = output->data<T>();
+
+    // Calculate the size of each item in sequence
+    int64_t item_size = input.numel() / input.dims()[0];
+    auto lod = input.lod()[0];
+    int seq_num = static_cast<int>(lod.size()) - 1;
+    for (int i = 0; i < seq_num; ++i) {
+      // Calculate the length of each sequence
+      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      // Point to the begin of next sequence
+      in_data += seq_len * item_size;
+      // Copy the last item of sequence to output
+      std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
+      out_data += item_size;
+    }
+  }
+};
+
+template <typename T>
+class FirstSeqPoolFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& input,
+                  framework::Tensor* output) {
+    // Create pointers to input and output data
+    auto* in_data = input.data<T>();
+    auto* out_data = output->data<T>();
+
+    // Calculate the size of each item in sequence
+    int64_t item_size = input.numel() / input.dims()[0];
+    auto lod = input.lod()[0];
+    int seq_num = static_cast<int>(lod.size()) - 1;
+    for (int i = 0; i < seq_num; ++i) {
+      // Calculate the length of each sequence
+      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      // Copy the first item of sequence to output
+      std::memcpy(out_data, in_data, item_size * sizeof(T));
+      // Point to the next sequence
+      in_data += seq_len * item_size;
+      out_data += item_size;
+    }
+  }
+};
+
 template <typename T>
 class SequencePoolFunctor<platform::CPUDeviceContext, T> {
  public:
@@ -116,6 +168,16 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       max_pool(context, input, output, index);
       return;
     }
+    if (pooltype == "LAST") {
+      math::LastSeqPoolFunctor<T> last_pool;
+      last_pool(context, input, output);
+      return;
+    }
+    if (pooltype == "FIRST") {
+      math::FirstSeqPoolFunctor<T> first_pool;
+      first_pool(context, input, output);
+      return;
+    }
     auto lod = input.lod()[0];
     auto& place = *context.eigen_device();
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
@@ -133,10 +195,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       } else if (pooltype == "SQRT") {
        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                              std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "LAST") {
-        out_e.device(place) = in_e.chip(h - 1, 0);
-      } else if (pooltype == "FIRST") {
-        out_e.device(place) = in_e.chip(0, 0);
       } else {
        PADDLE_THROW("unsupported pooling pooltype");
      }
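
Note: the two new functors replace the Eigen chip-based LAST/FIRST branches with a single memcpy of one time step per sequence. As a rough illustration of the semantics only (a NumPy sketch, not Paddle code, assuming a level-0 LoD of [0, 2, 5] over a 5-row input):

    import numpy as np

    # 5 time steps, item_size = 3; LoD [0, 2, 5] => sequence 0 is rows 0-1, sequence 1 is rows 2-4
    x = np.arange(15, dtype=np.float32).reshape(5, 3)
    lod = [0, 2, 5]

    # FIRST pooling: the first row of each sequence
    first = np.stack([x[lod[i]] for i in range(len(lod) - 1)])
    # LAST pooling: the last row of each sequence
    last = np.stack([x[lod[i + 1] - 1] for i in range(len(lod) - 1)])

    print(first)  # rows 0 and 2
    print(last)   # rows 1 and 4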

paddle/scripts/paddle_build.sh

Lines changed: 57 additions & 17 deletions
@@ -33,6 +33,7 @@ function print_usage() {
     ${BLUE}single_test${NONE}: run a single unit test
     ${BLUE}bind_test${NONE}: parallel tests bind to different GPU
     ${BLUE}doc${NONE}: generate paddle documents
+    ${BLUE}gen_doc_lib${NONE}: generate paddle documents library
     ${BLUE}html${NONE}: convert C++ source code into HTML
     ${BLUE}dockerfile${NONE}: generate paddle release dockerfile
     ${BLUE}capi${NONE}: generate paddle CAPI package
@@ -431,24 +432,60 @@ EOF
     linkchecker doc/v2/cn/html/index.html
     linkchecker doc/v2/api/en/html/index.html
 
-    if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
+    # if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
+    #
+    # # Deploy to the the content server if its a "develop" or "release/version" branch
+    # # The "develop_doc" branch is reserved to test full deploy process without impacting the real content.
+    # if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
+    #    PPO_SCRIPT_BRANCH=develop
+    # elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
+    #    PPO_SCRIPT_BRANCH=master
+    # else
+    #    # Early exit, this branch doesn't require documentation build
+    #    return 0;
+    # fi
+    # # Fetch the paddlepaddle.org deploy_docs.sh from the appopriate branch
+    # export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
+    # export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
+    # cd ..
+    # curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
+    # cd -
+}
 
-    # Deploy to the the content server if its a "develop" or "release/version" branch
-    # The "develop_doc" branch is reserved to test full deploy process without impacting the real content.
-    if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
-        PPO_SCRIPT_BRANCH=develop
-    elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
-        PPO_SCRIPT_BRANCH=master
-    else
-        # Early exit, this branch doesn't require documentation build
-        return 0;
-    fi
-    # Fetch the paddlepaddle.org deploy_docs.sh from the appopriate branch
-    export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
-    export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
-    cd ..
-    curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
-    cd -
+function gen_doc_lib() {
+    mkdir -p ${PADDLE_ROOT}/build
+    cd ${PADDLE_ROOT}/build
+    cat <<EOF
+    ========================================
+    Building documentation library ...
+    In /paddle/build
+    ========================================
+EOF
+    cmake .. \
+        -DCMAKE_BUILD_TYPE=Release \
+        -DWITH_DOC=ON \
+        -DWITH_GPU=OFF \
+        -DWITH_MKL=OFF \
+        -DWITH_FLUID_ONLY=ON
+
+    local LIB_TYPE=$1
+    case $LIB_TYPE in
+      full)
+        # Build full Paddle Python module. Will timeout without caching 'copy_paddle_pybind' first
+        make -j `nproc` gen_proto_py framework_py_proto copy_paddle_pybind paddle_python
+        ;;
+      pybind)
+        # Build paddle pybind library. Takes 49 minutes to build. Might timeout
+        make -j `nproc` copy_paddle_pybind
+        ;;
+      proto)
+        # Even smaller library.
+        make -j `nproc` framework_py_proto
+        ;;
+      *)
+        exit 0
+        ;;
+    esac
 }
 
 function gen_html() {
@@ -608,6 +645,9 @@ function main() {
       doc)
         gen_docs
         ;;
+      gen_doc_lib)
+        gen_doc_lib $2
+        ;;
       html)
         gen_html
         ;;
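
Note: the new sub-command is dispatched through main() above, so a CI job would invoke it as paddle/scripts/paddle_build.sh gen_doc_lib with one of full, pybind, or proto (anything else is a no-op exit). A minimal sketch of driving it from a Python CI harness, assuming a PaddlePaddle checkout as the working directory:

    import subprocess

    # Sketch only: paths assume a PaddlePaddle source checkout; LIB_TYPE is "full", "pybind", or "proto".
    subprocess.check_call(
        ["bash", "paddle/scripts/paddle_build.sh", "gen_doc_lib", "pybind"])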

python/paddle/fluid/tests/unittests/dist_transformer.py

Lines changed: 10 additions & 8 deletions
@@ -92,7 +92,7 @@ class TrainTaskConfig(object):
     src_vocab_fpath = data_path + "vocab.bpe.32000"
     trg_vocab_fpath = data_path + "vocab.bpe.32000"
     train_file_pattern = data_path + "train.tok.clean.bpe.32000.en-de"
-    val_file_pattern = data_path + "newstest2013.tok.bpe.32000.en-de"
+    val_file_pattern = data_path + "newstest2013.tok.bpe.32000.en-de.cut"
     pool_size = 2000
     sort_type = None
     local = True
@@ -624,11 +624,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                 init = True
 
             # Validate and save the model for inference.
-            if TrainTaskConfig.val_file_pattern is not None:
-                val_avg_cost, val_ppl = test()
-                print("[%f]" % val_avg_cost)
-            else:
-                assert (False)
+            if batch_id == 0 or batch_id == 4:
+                if TrainTaskConfig.val_file_pattern is not None:
+                    val_avg_cost, val_ppl = test()
+                    print("[%f]" % val_avg_cost)
+                else:
+                    assert (False)
 
 
 #import transformer_reader as reader
@@ -1701,8 +1702,9 @@ def run_pserver(self, args):
         exe.run(startup_prog)
         exe.run(pserver_prog)
 
-    def run_trainer(self, place, args):
-
+    def run_trainer(self, use_cuda, args):
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        TrainTaskConfig.use_gpu = use_cuda
         sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
             args.is_dist, not args.sync_mode)

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 10 additions & 9 deletions
@@ -61,9 +61,10 @@ def run_pserver(self, args):
         exe.run(startup_prog)
         exe.run(pserver_prog)
 
-    def run_trainer(self, place, args):
+    def run_trainer(self, use_cuda, args):
         import paddle
         import paddle.fluid as fluid
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
         if args.mem_opt:
@@ -91,7 +92,7 @@ def run_trainer(self, place, args):
             build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
 
         exe = fluid.ParallelExecutor(
-            True,
+            use_cuda,
             loss_name=avg_cost.name,
             exec_strategy=strategy,
             build_strategy=build_stra)
@@ -142,9 +143,8 @@ def runtime_main(test_class):
     if args.role == "pserver" and args.is_dist:
         model.run_pserver(args)
     else:
-        p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        model.run_trainer(p, args)
+        use_cuda = True if core.is_compiled_with_cuda() else False
+        model.run_trainer(use_cuda, args)
 
 
 import paddle.compat as cpt
@@ -225,11 +225,12 @@ def _wait_ps_ready(self, pid):
     def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # TODO(typhoonzero): should auto adapt GPU count on the machine.
         required_envs = {
-            "PATH": os.getenv("PATH"),
-            "PYTHONPATH": os.getenv("PYTHONPATH"),
-            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH"),
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
-            "FLAGS_cudnn_deterministic": "1"
+            "FLAGS_cudnn_deterministic": "1",
+            "CPU_NUM": "1"
         }
 
         if check_error_log:
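
Note on the os.getenv(..., "") defaults: required_envs appears to be passed as the environment of the launched trainer/pserver processes, and a None value (from an unset variable such as LD_LIBRARY_PATH) would make subprocess raise before the test process ever starts. A minimal sketch of the pattern this guards, with a placeholder child command rather than the real test runner:

    import os
    import subprocess
    import sys

    env = {
        "PATH": os.getenv("PATH", ""),                        # "" rather than None when unset
        "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
        "CPU_NUM": "1",                                       # pin CPU parallelism, as in the diff
    }
    # With a None value anywhere in env, Popen raises TypeError instead of running the child.
    subprocess.check_call([sys.executable, "-c", "print('env ok')"], env=env)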

python/paddle/fluid/tests/unittests/test_dist_transformer.py

Lines changed: 9 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 from __future__ import print_function
 
+import os
 import unittest
 import paddle
 from test_dist_base import TestDistBase
@@ -44,6 +45,14 @@ def download_files():
     test_url = url_prefix + 'newstest2013.tok.bpe.32000.en-de'
     test_md5 = '9dd74a266dbdb25314183899f269b4a2'
     paddle.dataset.common.download(test_url, 'test_dist_transformer', test_md5)
+    # cut test data for faster CI
+    orig_path = os.path.join(paddle.dataset.common.DATA_HOME,
+                             "test_dist_transformer",
+                             "newstest2013.tok.bpe.32000.en-de")
+    head_path = os.path.join(paddle.dataset.common.DATA_HOME,
+                             "test_dist_transformer",
+                             "newstest2013.tok.bpe.32000.en-de.cut")
+    os.system("head -n10 %s > %s" % (orig_path, head_path))
 
 
 class TestDistTransformer2x2Sync(TestDistBase):
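
Note: the head -n10 shell-out above is what produces the truncated .cut validation file that dist_transformer.py now points val_file_pattern at. A portable equivalent, should the shell pipeline be unavailable (a sketch; cut_file is a hypothetical helper, not part of the test):

    import itertools

    def cut_file(src, dst, n=10):
        # Copy only the first n lines of src into dst.
        with open(src) as fin, open(dst, "w") as fout:
            fout.writelines(itertools.islice(fin, n))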
