Skip to content

Commit e84d3a7

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_parallel_executor_tests
2 parents 1de9ede + fee5b24 commit e84d3a7

22 files changed

+423
-208
lines changed

CMakeLists.txt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F
3939
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
4040
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
4141
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
42-
option(WITH_TENSORRT "Compile PaddlePaddle with TensorRT support." OFF)
4342
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
4443
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
4544
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
@@ -180,13 +179,9 @@ set(EXTERNAL_LIBS
180179

181180
if(WITH_GPU)
182181
include(cuda)
182+
include(tensorrt)
183183
endif(WITH_GPU)
184184

185-
# TensorRT depends on GPU.
186-
if (NOT WITH_GPU)
187-
set(WITH_TENSORRT OFF)
188-
endif()
189-
190185
if(WITH_AMD_GPU)
191186
find_package(HIP)
192187
include(hip)

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
4646
RUN curl -s -q https://glide.sh/get | sh
4747

4848
# Install TensorRT
49-
# The unnecessary files have been removed to make the library small.
49+
# The unnecessary files have been removed to make the library small. It only contains include and lib now.
5050
RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
5151
tar -xz -C /usr/local && \
5252
cp -rf /usr/local/TensorRT/include /usr && \

cmake/configure.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,16 @@ if(WITH_GPU)
8080
# Include cuda and cudnn
8181
include_directories(${CUDNN_INCLUDE_DIR})
8282
include_directories(${CUDA_TOOLKIT_INCLUDE})
83+
84+
if(TENSORRT_FOUND)
85+
if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
86+
message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile")
87+
endif()
88+
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
89+
message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile")
90+
endif()
91+
include_directories(${TENSORRT_INCLUDE_DIR})
92+
endif()
8393
elseif(WITH_AMD_GPU)
8494
add_definitions(-DPADDLE_WITH_HIP)
8595
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")

cmake/tensorrt.cmake

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
if(NOT WITH_GPU)
2+
return()
3+
endif()
4+
5+
set(TENSORRT_ROOT "/usr" CACHE PATH "TENSORRT ROOT")
6+
find_path(TENSORRT_INCLUDE_DIR NvInfer.h
7+
PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include
8+
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include
9+
NO_DEFAULT_PATH
10+
)
11+
12+
find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
13+
PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib
14+
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib
15+
NO_DEFAULT_PATH
16+
DOC "Path to TensorRT library.")
17+
18+
if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
19+
set(TENSORRT_FOUND ON)
20+
else()
21+
set(TENSORRT_FOUND OFF)
22+
endif()
23+
24+
if(TENSORRT_FOUND)
25+
file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
26+
string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
27+
"${TENSORRT_VERSION_FILE_CONTENTS}")
28+
string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
29+
TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
30+
31+
message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
32+
"Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
33+
endif()

paddle/fluid/framework/parallel_executor.cc

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,9 @@ void ParallelExecutor::BCastParamsToGPUs(
155155
#endif
156156
}
157157

158-
void ParallelExecutor::Run(
159-
const std::vector<std::string> &fetch_tensors,
160-
const std::string &fetched_var_name,
161-
const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
158+
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
159+
const std::string &fetched_var_name) {
162160
platform::RecordBlock b(0);
163-
SplitTensorToPlaces(feed_tensors);
164-
165161
// Create local scopes.
166162
for (auto &scope : member_->local_scopes_) {
167163
Scope &local_scope = scope->NewScope();
@@ -195,14 +191,28 @@ void ParallelExecutor::Run(
195191
auto &local_scope =
196192
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
197193
scope->DeleteScope(local_scope);
198-
local_scope = nullptr;
199194
}
200195
}
201196

202-
void ParallelExecutor::SplitTensorToPlaces(
203-
const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
204-
for (auto it : feed_tensors) {
205-
auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
197+
void ParallelExecutor::FeedTensorsIntoLocalScopes(
198+
const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
199+
PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());
200+
201+
for (size_t i = 0; i < tensors.size(); ++i) {
202+
auto &map = tensors[i];
203+
auto *scope = member_->local_scopes_[i];
204+
for (auto &pair : map) {
205+
auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>();
206+
trg->ShareDataWith(pair.second);
207+
trg->set_lod(pair.second.lod());
208+
}
209+
}
210+
}
211+
212+
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
213+
const std::unordered_map<std::string, LoDTensor> &tensors) {
214+
for (auto pair : tensors) {
215+
auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
206216
PADDLE_ENFORCE_EQ(
207217
member_->places_.size(), lod_tensors.size(),
208218
"The number of samples of current batch is less than the count of "
@@ -211,7 +221,7 @@ void ParallelExecutor::SplitTensorToPlaces(
211221
for (size_t j = 0; j < member_->places_.size(); ++j) {
212222
// TODO(panxy0718): Do I need to delete this var?
213223
auto t =
214-
member_->local_scopes_[j]->Var(it.first)->GetMutable<LoDTensor>();
224+
member_->local_scopes_[j]->Var(pair.first)->GetMutable<LoDTensor>();
215225
t->ShareDataWith(lod_tensors[j]);
216226
t->set_lod(lod_tensors[j].lod());
217227
}

paddle/fluid/framework/parallel_executor.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,22 @@ class ParallelExecutor {
4444

4545
std::vector<Scope*>& GetLocalScopes();
4646

47+
/**
48+
* Feed tensors to local scopes. The size of tensors should be equal to the
49+
* size of local scopes.
50+
*/
51+
void FeedTensorsIntoLocalScopes(
52+
const std::vector<std::unordered_map<std::string, LoDTensor>>& tensors);
53+
54+
void FeedAndSplitTensorIntoLocalScopes(
55+
const std::unordered_map<std::string, LoDTensor>& tensors);
56+
4757
void Run(const std::vector<std::string>& fetch_tensors,
48-
const std::string& fetched_var_name,
49-
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
58+
const std::string& fetched_var_name);
5059

5160
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
5261

5362
private:
54-
void SplitTensorToPlaces(
55-
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
56-
5763
ParallelExecutorPrivate* member_;
5864
};
5965

paddle/fluid/inference/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ endif()
2121

2222
if(WITH_TESTING)
2323
add_subdirectory(tests/book)
24-
if (WITH_TENSORRT)
24+
if (TENSORRT_FOUND)
2525
add_subdirectory(tensorrt)
2626
endif()
2727
endif()

paddle/fluid/platform/dynload/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
22

33
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc nccl.cc)
4-
if (WITH_TENSORRT)
4+
if (TENSORRT_FOUND)
55
list(APPEND CUDA_SRCS tensorrt.cc)
66
endif()
77

paddle/fluid/pybind/pybind.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,11 +505,19 @@ All parameter, weight, gradient are variables in Paddle.
505505
scope, local_scopes, allow_op_delay);
506506
})
507507
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
508+
// NOTE: even if we return a vec<Scope*>* to Python using reference policy,
509+
// we still cannot get local_scope from this vector, since the elements
510+
// of vec<Scope*> will be freed by the Python GC. We can only return Scope*
511+
// one by one and mark each of them as a reference.
508512
.def("local_scopes",
509513
[](ParallelExecutor &self) -> std::vector<Scope *> * {
510514
return &self.GetLocalScopes();
511515
},
512516
py::return_value_policy::reference)
517+
.def("feed_tensors_into_local_scopes",
518+
&ParallelExecutor::FeedTensorsIntoLocalScopes)
519+
.def("feed_and_split_tensor_into_local_scopes",
520+
&ParallelExecutor::FeedAndSplitTensorIntoLocalScopes)
513521
.def("run", &ParallelExecutor::Run);
514522

515523
BindRecordIOWriter(&m);

paddle/fluid/pybind/tensor_py.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ void PyCUDATensorSetFromArray(
190190
static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
191191
paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
192192
cudaMemcpyHostToDevice, dev_ctx->stream());
193+
// NOTE: For safety, wait here for the copy to complete.
194+
// This is because the CPU-side array.data() could be destroyed after this method returns.
195+
// If we made this method async, it could copy data from a memory buffer
196+
// that has already been freed.
197+
dev_ctx->Wait();
193198
}
194199

195200
template <>
@@ -216,6 +221,11 @@ void PyCUDATensorSetFromArray(
216221
paddle::platform::GpuMemcpyAsync(dst, array.data(),
217222
sizeof(uint16_t) * array.size(),
218223
cudaMemcpyHostToDevice, dev_ctx->stream());
224+
// NOTE: For safety, wait here for the copy to complete.
225+
// This is because the CPU-side array.data() could be destroyed after this method returns.
226+
// If we made this method async, it could copy data from a memory buffer
227+
// that has already been freed.
228+
dev_ctx->Wait();
219229
}
220230

221231
template <typename T>

0 commit comments

Comments
 (0)