Skip to content

Commit b4c826c

Browse files
committed
Merge remote-tracking branch 'ups/develop' into fea/jit/rnn
test=develop
2 parents ce31deb + a8d3aaa commit b4c826c

File tree

12 files changed

+189
-16
lines changed

12 files changed

+189
-16
lines changed

cmake/inference_lib.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ copy(framework_lib DEPS ${framework_lib_deps}
166166

167167
set(module "memory")
168168
copy(memory_lib
169-
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/detail/*.h
170-
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/detail
169+
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/detail/*.h ${src_dir}/${module}/allocation/*.h
170+
DSTS ${dst_dir}/${module} ${dst_dir}/${module}/detail ${dst_dir}/${module}/allocation
171171
)
172172

173173
set(inference_deps paddle_fluid_shared paddle_fluid)

paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
4646
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
4747
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
4848
"elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
49-
"conv2d_transpose"});
49+
"conv2d_transpose", "leaky_relu"});
5050
if (!node->IsOp()) return false;
5151

5252
if (teller_set.count(node->Op()->Type())) {

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,4 +551,5 @@ USE_TRT_CONVERTER(pad);
551551
USE_TRT_CONVERTER(split);
552552
USE_TRT_CONVERTER(prelu);
553553
USE_TRT_CONVERTER(conv2d_transpose);
554+
USE_TRT_CONVERTER(leaky_relu);
554555
#endif

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
nv_library(tensorrt_converter
33
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
44
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
5-
pad_op.cc split_op.cc prelu_op.cc
5+
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc
66
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
77

88
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -38,3 +38,5 @@ nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
3838
nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
3939
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin
4040
prelu_op SERIAL)
41+
nv_test(test_trt_leaky_relu_op SRCS test_leaky_relu_op.cc leaky_relu_op.cc
42+
DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op SERIAL)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
16+
17+
namespace paddle {
18+
namespace inference {
19+
namespace tensorrt {
20+
21+
// LeakyRelu converter from fluid to tensorRT.
// The op is expressed with four TensorRT layers (Scale, Activation(kRELU),
// Scale, ElementWise(kSUM)) that together compute
//   y = alpha * x + (1 - alpha) * relu(x)
// The converter requires exactly one "X" input and one "Out" output and reads
// the float attribute "alpha" from the op description.
22+
class LeakyReluOpConverter : public OpConverter {
23+
public:
24+
// Builds the layers for `op` on engine_ and registers the result tensor under
// the op's "Out" name. When `test_mode` is true the output is additionally
// declared on the engine via DeclareOutput(). `scope` is not used here.
void operator()(const framework::proto::OpDesc& op,
25+
const framework::Scope& scope, bool test_mode) override {
26+
VLOG(4) << "convert fluid leaky_relu op to tensorrt layer";
27+
28+
framework::OpDesc op_desc(op, nullptr);
29+
// Declare inputs: exactly one "X" tensor is expected
30+
int input_num = op_desc.Input("X").size();
31+
PADDLE_ENFORCE(input_num == 1);
32+
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
33+
// Get output: exactly one "Out" tensor is expected
34+
size_t output_num = op_desc.Output("Out").size();
35+
PADDLE_ENFORCE(output_num == 1);
36+
// Get attrs
37+
float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
38+
39+
// Two-element CPU tensor backing the scale weights below:
// alpha_data[0] = alpha, alpha_data[1] = 1 - alpha
platform::CPUPlace place;
40+
std::unique_ptr<framework::LoDTensor> alpha_tensor(
41+
new framework::LoDTensor());
42+
alpha_tensor->Resize(framework::make_ddim({2}));
43+
float* alpha_data = alpha_tensor->mutable_data<float>(place);
44+
alpha_data[0] = alpha;
45+
alpha_data[1] = 1.f - alpha;
46+
// the leaky relu formula y = (x > 0) ? x : alpha * x is equivalent to
47+
// y = alpha * x + ((x > 0) ? (1 - alpha) * x : 0)
48+
TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT, &alpha_data[0], 1};
49+
// shift and power are passed as empty weights (nullptr, count 0)
TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0};
50+
TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0};
51+
// y_scale = alpha * x
52+
auto* scale_layer = TRT_ENGINE_ADD_LAYER(
53+
engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM, shift.get(),
54+
scale.get(), power.get());
55+
PADDLE_ENFORCE(nullptr != scale_layer);
56+
// y_relu = (x > 0) ? x : 0
57+
auto* relu_layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input,
58+
nvinfer1::ActivationType::kRELU);
59+
PADDLE_ENFORCE(nullptr != relu_layer);
60+
// y_scale_relu = (1 - alpha) * y_relu
61+
TensorRTEngine::Weight sub_scale{nvinfer1::DataType::kFLOAT, &alpha_data[1],
62+
1};
63+
auto* scale_relu_layer =
64+
TRT_ENGINE_ADD_LAYER(engine_, Scale, *(relu_layer->getOutput(0)),
65+
nvinfer1::ScaleMode::kUNIFORM, shift.get(),
66+
sub_scale.get(), power.get());
67+
PADDLE_ENFORCE(nullptr != scale_relu_layer);
68+
// y = y_scale + y_scale_relu
auto* output_layer =
69+
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *(scale_layer->getOutput(0)),
70+
*(scale_relu_layer->getOutput(0)),
71+
nvinfer1::ElementWiseOperation::kSUM);
72+
PADDLE_ENFORCE(nullptr != output_layer);
73+
// keep the alpha tensor alive (owned by engine_->weight_map) so that its
// memory, which the scale weights point into, is not released; the key
// "<Out>_alpha" must not already be present in the map
74+
std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
75+
PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
76+
engine_->weight_map.end());
77+
engine_->weight_map[alpha_name] = std::move(alpha_tensor);
78+
79+
// Name the result tensor after the op's output and register it on the
// engine; the layer itself is named "leaky_relu (Output: <Out>)"
std::string layer_name = "leaky_relu (Output: ";
80+
auto output_name = op_desc.Output("Out")[0];
81+
output_layer->getOutput(0)->setName(output_name.c_str());
82+
engine_->SetITensor(output_name, output_layer->getOutput(0));
83+
layer_name += output_name;
84+
if (test_mode) {
85+
engine_->DeclareOutput(output_name);
86+
}
87+
output_layer->setName((layer_name + ")").c_str());
88+
}
89+
};
90+
91+
} // namespace tensorrt
92+
} // namespace inference
93+
} // namespace paddle
94+
95+
REGISTER_TRT_OP_CONVERTER(leaky_relu, LeakyReluOpConverter);
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
17+
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace tensorrt {
22+
23+
// Runs a leaky_relu op (alpha = 0.1) through TRTConvertValidation with one
// DimsCHW(3, 2, 2) input variable and one output variable of the same dims.
TEST(leaky_relu_op, test_leaky_relu) {
24+
std::unordered_set<std::string> parameters;
25+
framework::Scope scope;
26+
// NOTE(review): 10 and 1000 are presumably the max batch size and the
// workspace size -- confirm against ut_helper.h
TRTConvertValidation validator(10, parameters, scope, 1000);
27+
validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2));
28+
validator.DeclOutputVar("leaky_relu_out", nvinfer1::DimsCHW(3, 2, 2));
29+
30+
// Prepare Op description
31+
framework::OpDesc desc;
32+
desc.SetType("leaky_relu");
33+
desc.SetInput("X", {"leaky_relu_input"});
34+
desc.SetOutput("Out", {"leaky_relu_out"});
35+
36+
// Slope attribute read by the converter via GetAttr("alpha")
desc.SetAttr("alpha", 0.1f);
37+
38+
validator.SetOp(*desc.Proto());
39+
40+
// NOTE(review): the argument is presumably the batch size -- confirm
validator.Execute(1);
41+
}
42+
43+
} // namespace tensorrt
44+
} // namespace inference
45+
} // namespace paddle
46+
47+
// USE_OP(leaky_relu);
48+
USE_OP(leaky_relu);
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
nv_library(tensorrt_plugin
22
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
3-
DEPS enforce device_context)
3+
DEPS enforce tensorrt_engine)

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ function(inference_analysis_api_test_with_fake_data target install_dir filename
2727
endfunction()
2828

2929
# RNN1
30-
if(NOT APPLE)
30+
if(NOT APPLE AND WITH_MKLML)
3131
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
3232
download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz")
3333
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc)
3434
else()
35-
# TODO: fix this test on MACOS, the reason is that
36-
# fusion_seqexpand_concat_fc_op is not supported on MACOS
37-
message(WARNING "These tests has been disabled in OSX before being fixed: \n test_analyzer_rnn1")
35+
# TODO: fix this test on MACOS and OPENBLAS, the reason is that
36+
# fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS
37+
message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_rnn1")
3838
endif()
3939

4040
# RNN2

paddle/fluid/operators/stack_op.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,32 @@ class StackKernel : public framework::OpKernel<T> {
147147
auto &dim = x[0]->dims();
148148
for (auto i = 0; i < axis; ++i) pre *= dim[i];
149149
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
150-
int total_num = pre * n * post;
151150

152-
auto &dev_ctx = ctx.template device_context<DeviceContext>();
153151
#ifdef __NVCC__
152+
int total_num = pre * n * post;
153+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
154+
154155
thrust::device_vector<const T *> device_x_vec(x_datas);
155156
auto x_data_arr = device_x_vec.data().get();
156-
#else
157-
auto x_data_arr = x_datas.data();
158-
#endif
157+
159158
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
160-
#ifdef __NVCC__
159+
161160
// Wait() must be called because device_x_vec may be destructed before
162161
// kernel ends
163162
dev_ctx.Wait();
163+
#else
164+
auto x_data_arr = x_datas.data();
165+
166+
size_t x_offset = 0;
167+
size_t y_offset = 0;
168+
for (int i = 0; i < pre; i++) {
169+
for (int j = 0; j < n; j++) {
170+
std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
171+
post * sizeof(T));
172+
y_offset += post;
173+
}
174+
x_offset += post;
175+
}
164176
#endif
165177
}
166178
};

paddle/fluid/platform/init.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ std::once_flag p2p_init_flag;
3838

3939
void InitGflags(std::vector<std::string> argv) {
4040
std::call_once(gflags_init_flag, [&]() {
41+
FLAGS_logtostderr = true;
4142
argv.insert(argv.begin(), "dummy");
4243
int argc = argv.size();
4344
char **arr = new char *[argv.size()];

0 commit comments

Comments
 (0)