Skip to content

Commit b3cf429

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_grpc_server_ready_condition
2 parents e7ac709 + 1945b72 commit b3cf429

File tree

13 files changed

+227
-182
lines changed

13 files changed

+227
-182
lines changed

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
6565
// Initialize the inference network, so that TensorRT layers can add to this
6666
// network.
6767
void InitNetwork() {
68-
infer_builder_.reset(createInferBuilder(logger_));
68+
infer_builder_.reset(createInferBuilder(&logger_));
6969
infer_network_.reset(infer_builder_->createNetwork());
7070
}
7171
// After finishing adding ops, freeze this network and create the execution

paddle/fluid/inference/tensorrt/helper.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
4646
// The following two API are implemented in TensorRT's header file, cannot load
4747
// from the dynamic library. So create our own implementation and directly
4848
// trigger the method from the dynamic library.
49-
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
49+
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
5050
return static_cast<nvinfer1::IBuilder*>(
51-
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
51+
dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
5252
}
53-
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
53+
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
5454
return static_cast<nvinfer1::IRuntime*>(
55-
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
55+
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
5656
}
5757

5858
// A logger for create TensorRT infer builder.
@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
8080
return *x;
8181
}
8282

83-
virtual ~NaiveLogger() override {}
83+
~NaiveLogger() override {}
8484
};
8585

8686
} // namespace tensorrt

paddle/fluid/inference/tensorrt/test_tensorrt.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <cuda.h>
16+
#include <cuda_runtime_api.h>
1517
#include <glog/logging.h>
1618
#include <gtest/gtest.h>
1719
#include "NvInfer.h"
18-
#include "cuda.h"
19-
#include "cuda_runtime_api.h"
2020
#include "paddle/fluid/platform/dynload/tensorrt.h"
2121

2222
namespace dy = paddle::platform::dynload;
@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {
4343

4444
class ScopedWeights {
4545
public:
46-
ScopedWeights(float value) : value_(value) {
46+
explicit ScopedWeights(float value) : value_(value) {
4747
w.type = nvinfer1::DataType::kFLOAT;
4848
w.values = &value_;
4949
w.count = 1;
@@ -58,13 +58,13 @@ class ScopedWeights {
5858
// The following two API are implemented in TensorRT's header file, cannot load
5959
// from the dynamic library. So create our own implementation and directly
6060
// trigger the method from the dynamic library.
61-
nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
61+
nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
6262
return static_cast<nvinfer1::IBuilder*>(
63-
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
63+
dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
6464
}
65-
nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
65+
nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
6666
return static_cast<nvinfer1::IRuntime*>(
67-
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
67+
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
6868
}
6969

7070
const char* kInputTensor = "input";
@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";
7474
nvinfer1::IHostMemory* CreateNetwork() {
7575
Logger logger;
7676
// Create the engine.
77-
nvinfer1::IBuilder* builder = createInferBuilder(logger);
77+
nvinfer1::IBuilder* builder = createInferBuilder(&logger);
7878
ScopedWeights weights(2.);
7979
ScopedWeights bias(3.);
8080

@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
103103
return model;
104104
}
105105

106-
void Execute(nvinfer1::IExecutionContext& context, const float* input,
106+
void Execute(nvinfer1::IExecutionContext* context, const float* input,
107107
float* output) {
108-
const nvinfer1::ICudaEngine& engine = context.getEngine();
108+
const nvinfer1::ICudaEngine& engine = context->getEngine();
109109
// Two binds, input and output
110110
ASSERT_EQ(engine.getNbBindings(), 2);
111111
const int input_index = engine.getBindingIndex(kInputTensor);
@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
119119
// Copy the input to the GPU, execute the network, and copy the output back.
120120
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
121121
cudaMemcpyHostToDevice, stream));
122-
context.enqueue(1, buffers, stream, nullptr);
122+
context->enqueue(1, buffers, stream, nullptr);
123123
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
124124
cudaMemcpyDeviceToHost, stream));
125125
cudaStreamSynchronize(stream);
@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
136136

137137
// Use the model to create an engine and an execution context.
138138
Logger logger;
139-
nvinfer1::IRuntime* runtime = createInferRuntime(logger);
139+
nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
140140
nvinfer1::ICudaEngine* engine =
141141
runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
142142
model->destroy();
@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
145145
// Execute the network.
146146
float input = 1234;
147147
float output;
148-
Execute(*context, &input, &output);
148+
Execute(context, &input, &output);
149149
EXPECT_EQ(output, input * 2 + 3);
150150

151151
// Destroy the engine.

paddle/fluid/operators/math/pooling.cc

Lines changed: 55 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
1514
#include "paddle/fluid/operators/math/pooling.h"
15+
#include <algorithm>
16+
#include <vector>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -27,9 +28,10 @@ template <typename PoolProcess, typename T>
2728
class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
2829
public:
2930
void operator()(const platform::CPUDeviceContext& context,
30-
const framework::Tensor& input, std::vector<int>& ksize,
31-
std::vector<int>& strides, std::vector<int>& paddings,
32-
PoolProcess pool_process, framework::Tensor* output) {
31+
const framework::Tensor& input, const std::vector<int>& ksize,
32+
const std::vector<int>& strides,
33+
const std::vector<int>& paddings, PoolProcess pool_process,
34+
framework::Tensor* output) {
3335
const int batch_size = input.dims()[0];
3436
const int input_height = input.dims()[2];
3537
const int input_width = input.dims()[3];
@@ -63,11 +65,11 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
6365
T ele = pool_process.initial();
6466
for (int h = hstart; h < hend; ++h) {
6567
for (int w = wstart; w < wend; ++w) {
66-
pool_process.compute(ele, input_data[h * input_width + w]);
68+
pool_process.compute(input_data[h * input_width + w], &ele);
6769
}
6870
}
6971
int pool_size = (hend - hstart) * (wend - wstart);
70-
pool_process.finalize(ele, (static_cast<T>(pool_size)));
72+
pool_process.finalize(static_cast<T>(pool_size), &ele);
7173
output_data[ph * output_width + pw] = ele;
7274
}
7375
}
@@ -86,13 +88,12 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
8688
template <typename PoolProcess, class T>
8789
class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
8890
public:
89-
void operator()(const platform::CPUDeviceContext& context,
90-
const framework::Tensor& input,
91-
const framework::Tensor& output,
92-
const framework::Tensor& output_grad, std::vector<int>& ksize,
93-
std::vector<int>& strides, std::vector<int>& paddings,
94-
PoolProcess pool_grad_process,
95-
framework::Tensor* input_grad) {
91+
void operator()(
92+
const platform::CPUDeviceContext& context, const framework::Tensor& input,
93+
const framework::Tensor& output, const framework::Tensor& output_grad,
94+
const std::vector<int>& ksize, const std::vector<int>& strides,
95+
const std::vector<int>& paddings, PoolProcess pool_grad_process,
96+
framework::Tensor* input_grad) {
9697
const int batch_size = input.dims()[0];
9798
const int input_height = input.dims()[2];
9899
const int input_width = input.dims()[3];
@@ -131,8 +132,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
131132
input_data[h * input_width + w],
132133
output_data[ph * output_width + pw],
133134
output_grad_data[ph * output_width + pw],
134-
input_grad_data[h * input_width + w],
135-
static_cast<T>(scale));
135+
static_cast<T>(scale),
136+
input_grad_data + h * input_width + w);
136137
}
137138
}
138139
}
@@ -154,12 +155,11 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
154155
template <class T>
155156
class MaxPool2dGradFunctor<platform::CPUDeviceContext, T> {
156157
public:
157-
void operator()(const platform::CPUDeviceContext& context,
158-
const framework::Tensor& input,
159-
const framework::Tensor& output,
160-
const framework::Tensor& output_grad, std::vector<int>& ksize,
161-
std::vector<int>& strides, std::vector<int>& paddings,
162-
framework::Tensor* input_grad) {
158+
void operator()(
159+
const platform::CPUDeviceContext& context, const framework::Tensor& input,
160+
const framework::Tensor& output, const framework::Tensor& output_grad,
161+
const std::vector<int>& ksize, const std::vector<int>& strides,
162+
const std::vector<int>& paddings, framework::Tensor* input_grad) {
163163
const int batch_size = input.dims()[0];
164164
const int input_height = input.dims()[2];
165165
const int input_width = input.dims()[3];
@@ -246,9 +246,10 @@ template <typename PoolProcess, class T>
246246
class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
247247
public:
248248
void operator()(const platform::CPUDeviceContext& context,
249-
const framework::Tensor& input, std::vector<int>& ksize,
250-
std::vector<int>& strides, std::vector<int>& paddings,
251-
PoolProcess pool_process, framework::Tensor* output) {
249+
const framework::Tensor& input, const std::vector<int>& ksize,
250+
const std::vector<int>& strides,
251+
const std::vector<int>& paddings, PoolProcess pool_process,
252+
framework::Tensor* output) {
252253
const int batch_size = input.dims()[0];
253254
const int input_depth = input.dims()[2];
254255
const int input_height = input.dims()[3];
@@ -293,14 +294,14 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
293294
for (int h = hstart; h < hend; ++h) {
294295
for (int w = wstart; w < wend; ++w) {
295296
pool_process.compute(
296-
ele,
297-
input_data[(d * input_height + h) * input_width + w]);
297+
input_data[(d * input_height + h) * input_width + w],
298+
&ele);
298299
}
299300
}
300301
}
301302
int pool_size =
302303
(dend - dstart) * (hend - hstart) * (wend - wstart);
303-
pool_process.finalize(ele, static_cast<T>(pool_size));
304+
pool_process.finalize(static_cast<T>(pool_size), &ele);
304305
output_data[output_idx] = ele;
305306
}
306307
}
@@ -320,13 +321,12 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
320321
template <typename PoolProcess, class T>
321322
class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
322323
public:
323-
void operator()(const platform::CPUDeviceContext& context,
324-
const framework::Tensor& input,
325-
const framework::Tensor& output,
326-
const framework::Tensor& output_grad, std::vector<int>& ksize,
327-
std::vector<int>& strides, std::vector<int>& paddings,
328-
PoolProcess pool_grad_process,
329-
framework::Tensor* input_grad) {
324+
void operator()(
325+
const platform::CPUDeviceContext& context, const framework::Tensor& input,
326+
const framework::Tensor& output, const framework::Tensor& output_grad,
327+
const std::vector<int>& ksize, const std::vector<int>& strides,
328+
const std::vector<int>& paddings, PoolProcess pool_grad_process,
329+
framework::Tensor* input_grad) {
330330
const int batch_size = input.dims()[0];
331331
const int input_depth = input.dims()[2];
332332
const int input_height = input.dims()[3];
@@ -379,8 +379,8 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
379379
(pd * output_height + ph) * output_width + pw;
380380
pool_grad_process.compute(
381381
input_data[input_idx], output_data[output_idx],
382-
output_grad_data[output_idx],
383-
input_grad_data[input_idx], static_cast<T>(scale));
382+
output_grad_data[output_idx], static_cast<T>(scale),
383+
input_grad_data + input_idx);
384384
}
385385
}
386386
}
@@ -404,12 +404,11 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
404404
template <class T>
405405
class MaxPool3dGradFunctor<platform::CPUDeviceContext, T> {
406406
public:
407-
void operator()(const platform::CPUDeviceContext& context,
408-
const framework::Tensor& input,
409-
const framework::Tensor& output,
410-
const framework::Tensor& output_grad, std::vector<int>& ksize,
411-
std::vector<int>& strides, std::vector<int>& paddings,
412-
framework::Tensor* input_grad) {
407+
void operator()(
408+
const platform::CPUDeviceContext& context, const framework::Tensor& input,
409+
const framework::Tensor& output, const framework::Tensor& output_grad,
410+
const std::vector<int>& ksize, const std::vector<int>& strides,
411+
const std::vector<int>& paddings, framework::Tensor* input_grad) {
413412
const int batch_size = input.dims()[0];
414413
const int input_depth = input.dims()[2];
415414
const int input_height = input.dims()[3];
@@ -510,9 +509,10 @@ template <typename T1, typename T2>
510509
class MaxPool2dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
511510
public:
512511
void operator()(const platform::CPUDeviceContext& context,
513-
const framework::Tensor& input, std::vector<int>& ksize,
514-
std::vector<int>& strides, std::vector<int>& paddings,
515-
framework::Tensor* output, framework::Tensor* mask) {
512+
const framework::Tensor& input, const std::vector<int>& ksize,
513+
const std::vector<int>& strides,
514+
const std::vector<int>& paddings, framework::Tensor* output,
515+
framework::Tensor* mask) {
516516
const int batch_size = input.dims()[0];
517517
const int input_height = input.dims()[2];
518518
const int input_width = input.dims()[3];
@@ -576,8 +576,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
576576
public:
577577
void operator()(const platform::CPUDeviceContext& context,
578578
const framework::Tensor& output_grad,
579-
const framework::Tensor& mask, std::vector<int>& ksize,
580-
std::vector<int>& strides, std::vector<int>& paddings,
579+
const framework::Tensor& mask, const std::vector<int>& ksize,
580+
const std::vector<int>& strides,
581+
const std::vector<int>& paddings,
581582
framework::Tensor* input_grad) {
582583
const int batch_size = input_grad->dims()[0];
583584
const int input_height = input_grad->dims()[2];
@@ -628,9 +629,10 @@ template <typename T1, typename T2>
628629
class MaxPool3dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
629630
public:
630631
void operator()(const platform::CPUDeviceContext& context,
631-
const framework::Tensor& input, std::vector<int>& ksize,
632-
std::vector<int>& strides, std::vector<int>& paddings,
633-
framework::Tensor* output, framework::Tensor* mask) {
632+
const framework::Tensor& input, const std::vector<int>& ksize,
633+
const std::vector<int>& strides,
634+
const std::vector<int>& paddings, framework::Tensor* output,
635+
framework::Tensor* mask) {
634636
const int batch_size = input.dims()[0];
635637
const int input_depth = input.dims()[2];
636638
const int input_height = input.dims()[3];
@@ -708,8 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
708710
public:
709711
void operator()(const platform::CPUDeviceContext& context,
710712
const framework::Tensor& output_grad,
711-
const framework::Tensor& mask, std::vector<int>& ksize,
712-
std::vector<int>& strides, std::vector<int>& paddings,
713+
const framework::Tensor& mask, const std::vector<int>& ksize,
714+
const std::vector<int>& strides,
715+
const std::vector<int>& paddings,
713716
framework::Tensor* input_grad) {
714717
const int batch_size = input_grad->dims()[0];
715718
const int input_depth = input_grad->dims()[2];

0 commit comments

Comments
 (0)