
Commit f5fc9c3

feature/mul converter (#10841)
1 parent 8f7b020 commit f5fc9c3

File tree

9 files changed: +242 -21 lines changed

paddle/fluid/inference/analysis/helper.h

Lines changed: 9 additions & 2 deletions
@@ -24,6 +24,15 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+template <typename Vec>
+int AccuDims(Vec &&vec, int size) {
+  int res = 1;
+  for (int i = 0; i < size; i++) {
+    res *= std::forward<Vec>(vec)[i];
+  }
+  return res;
+}
+
 #define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
 /*
  * Map typeid to representation.

@@ -101,7 +110,5 @@ class OrderedRegistry {
 }  // namespace paddle
 
 #define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
-  \
   type__(const type__ &) = delete;              \
-  \
   void operator=(const type__ &) = delete;
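
The new AccuDims helper folds the first `size` extents of an indexable sequence into a single element count; the engine.cc changes later in this commit use it to size TensorRT bindings. A minimal self-contained sketch of its behavior (the demo main and the extent array are illustrative, not part of the commit):

#include <iostream>
#include <utility>

// Copy of the helper added to analysis/helper.h above.
template <typename Vec>
int AccuDims(Vec &&vec, int size) {
  int res = 1;
  for (int i = 0; i < size; i++) {
    res *= std::forward<Vec>(vec)[i];
  }
  return res;
}

int main() {
  // Extents of a 10x6 matrix, laid out the way TensorRT stores Dims::d.
  int d[2] = {10, 6};
  std::cout << AccuDims(d, 2) << "\n";  // prints 60
  return 0;
}
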
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,10 @@
-nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
+# Add TRT tests
+nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine)
 # This test is not stable
 # See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828
 #nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
 #        DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine
 #        SERIAL)
 nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
+nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 15 additions & 1 deletion
@@ -18,11 +18,25 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
+/*
+ * MulOp, IMatrixMultiplyLayer in TRT. This layer doesn't have weights.
+ */
 class MulOpConverter : public OpConverter {
  public:
   MulOpConverter() {}
   void operator()(const framework::proto::OpDesc& op) override {
-    LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
+    VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
+
+    framework::OpDesc op_desc(op, nullptr, nullptr);
+    // Declare inputs
+    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
+    // Neither input1 nor input2 needs to be transposed.
+    auto* layer = TRT_ENGINE_ADD_LAYER(
+        engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
+        *const_cast<nvinfer1::ITensor*>(input2), false);
+
+    engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
   }
 };
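
Behind the TRT_ENGINE_ADD_LAYER macro, this converter is asking TensorRT to build an IMatrixMultiplyLayer on the engine's network definition. A hedged sketch of the equivalent direct API call (the free function and its name are illustrative; the macro's actual expansion lives in op_converter.h, which this diff does not touch):

#include <NvInfer.h>

// Sketch: the raw TensorRT call the converter effectively makes.
// x and y are the ITensors registered for the op's "X" and "Y" inputs.
nvinfer1::ILayer* AddMulLayer(nvinfer1::INetworkDefinition* network,
                              nvinfer1::ITensor* x, nvinfer1::ITensor* y) {
  // false/false: neither operand is transposed, matching the converter above.
  return network->addMatrixMultiply(*x, false, *y, false);
}
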

paddle/fluid/inference/tensorrt/convert/test_activation_op.cc

Lines changed: 2 additions & 0 deletions
@@ -102,3 +102,5 @@ TEST(OpConverter, ConvertRelu) {
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+USE_OP(activation);
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(MulOpConverter, main) {
+  TRTConvertValidation validator(10, 1000);
+  validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
+  validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
+  validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("mul");
+  desc.SetInput("X", {"mul-X"});
+  desc.SetInput("Y", {"mul-Y"});
+  desc.SetOutput("Out", {"mul-Out"});
+
+  LOG(INFO) << "set OP";
+  validator.SetOp(*desc.Proto());
+  LOG(INFO) << "execute";
+
+  validator.Execute(10);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(mul);

paddle/fluid/inference/tensorrt/convert/test_op_converter.cc

Lines changed: 0 additions & 2 deletions
@@ -23,8 +23,6 @@ namespace tensorrt {
 TEST(OpConverter, ConvertBlock) {
   framework::ProgramDesc prog;
   auto* block = prog.MutableBlock(0);
-  auto* mul_op = block->AppendOp();
-  mul_op->SetType("mul");
   auto* conv2d_op = block->AppendOp();
   conv2d_op->SetType("conv2d");

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+/*
+ * This file implements a UT framework to validate the transformation of a
+ * Fluid Op into a TRT layer.
+ */
+
+#pragma once
+
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * Get a random float value in [low, high].
+ */
+float random(float low, float high) {
+  static std::random_device rd;
+  static std::mt19937 mt(rd());
+  std::uniform_real_distribution<double> dist(1.0, 10.0);
+  return dist(mt);
+}
+
+void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
+                     const platform::DeviceContext& ctx) {
+  auto dims = tensor->dims();
+  size_t num_elements = analysis::AccuDims(dims, dims.size());
+  PADDLE_ENFORCE_GT(num_elements, 0);
+  auto* data = tensor->mutable_data<float>(place);
+  for (size_t i = 0; i < num_elements; i++) {
+    *(data + i) = random(0., 1.);
+  }
+}
+
+/*
+ * Helper to validate the correctness between a Fluid Op and the corresponding
+ * TRT layer.
+ */
+class TRTConvertValidation {
+ public:
+  TRTConvertValidation() = delete;
+
+  TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) {
+    // create engine.
+    engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
+    engine_->InitNetwork();
+
+    PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
+  }
+
+  // Declare a Variable as input with random initialization.
+  void DeclInputVar(const std::string& name, const nvinfer1::Dims& dims) {
+    DeclVar(name, dims);
+    // Declare TRT inputs.
+    engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
+  }
+
+  void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
+    DeclVar(name, dims);
+  }
+
+  void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
+    platform::CPUPlace place;
+    platform::CPUDeviceContext ctx(place);
+
+    // Init Fluid tensor.
+    std::vector<int> dim_vec(dims.nbDims);
+    for (int i = 0; i < dims.nbDims; i++) {
+      dim_vec[i] = dims.d[i];
+    }
+    auto* x = scope_.Var(name);
+    auto* x_tensor = x->GetMutable<framework::LoDTensor>();
+    x_tensor->Resize(framework::make_ddim(dim_vec));
+    RandomizeTensor(x_tensor, place, ctx);
+  }
+
+  void SetOp(const framework::proto::OpDesc& desc) {
+    op_ = framework::OpRegistry::CreateOp(desc);
+
+    OpConverter op_converter;
+    op_converter.ConvertOp(desc, engine_.get());
+
+    engine_->FreezeNetwork();
+
+    // Declare outputs.
+    op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr));
+
+    // Set Inputs.
+    for (const auto& input : op_desc_->InputArgumentNames()) {
+      auto* var = scope_.FindVar(input);
+      PADDLE_ENFORCE(var);
+      auto tensor = var->GetMutable<framework::LoDTensor>();
+      engine_->SetInputFromCPU(
+          input, static_cast<void*>(tensor->data<float>()),
+          sizeof(float) *
+              analysis::AccuDims(tensor->dims(), tensor->dims().size()));
+    }
+  }
+
+  void Execute(int batch_size) {
+    // Execute Fluid Op
+    // Execute TRT
+    platform::CPUPlace place;
+    platform::CPUDeviceContext ctx(place);
+    engine_->Execute(batch_size);
+
+    op_->Run(scope_, place);
+
+    ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
+    for (const auto& output : op_desc_->OutputArgumentNames()) {
+      std::vector<float> fluid_out;
+      std::vector<float> trt_out(200);
+      engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float));
+
+      auto* var = scope_.FindVar(output);
+      auto tensor = var->GetMutable<framework::LoDTensor>();
+      framework::TensorToVector(*tensor, ctx, &fluid_out);
+      // Compare the two outputs.
+      ASSERT_FALSE(fluid_out.empty());
+      for (size_t i = 0; i < fluid_out.size(); i++) {
+        EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001);
+      }
+    }
+  }
+
+  framework::Scope& scope() { return scope_; }
+
+ private:
+  std::unique_ptr<TensorRTEngine> engine_;
+  cudaStream_t stream_;
+  framework::Scope scope_;
+  std::unique_ptr<framework::OperatorBase> op_;
+  std::unique_ptr<framework::OpDesc> op_desc_;
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
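
Two caveats in this helper as committed: random() ignores its low and high arguments and always samples from dist(1.0, 10.0), and the constructor hard-codes 10 and 1 << 10 rather than using its batch_size and workspace_size parameters. A minimal sketch of a bounds-honoring sampler (the name BoundedRandom is ours, not the commit's):

#include <random>

// Sketch: a uniform float sampler that actually uses its bounds.
float BoundedRandom(float low, float high) {
  static std::random_device rd;
  static std::mt19937 mt(rd());
  std::uniform_real_distribution<float> dist(low, high);  // samples [low, high)
  return dist(mt);
}
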

paddle/fluid/inference/tensorrt/engine.cc

Lines changed: 9 additions & 6 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <cuda.h>
 #include <glog/logging.h>
 #include <string>
+#include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/platform/enforce.h"

@@ -71,9 +72,10 @@ void TensorRTEngine::FreezeNetwork() {
   for (auto& item : buffer_sizes_) {
     if (item.second == 0) {
       auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
+      auto dims = infer_engine_->getBindingDimensions(slot_offset);
       item.second = kDataTypeSize[static_cast<int>(
                         infer_engine_->getBindingDataType(slot_offset))] *
-                    AccumDims(infer_engine_->getBindingDimensions(slot_offset));
+                    analysis::AccuDims(dims.d, dims.nbDims);
     }
     auto& buf = buffer(item.first);
     CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.

@@ -85,14 +87,15 @@ void TensorRTEngine::FreezeNetwork() {
 
 nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
                                                 nvinfer1::DataType dtype,
-                                                const nvinfer1::Dims& dim) {
+                                                const nvinfer1::Dims& dims) {
   PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                     name);
 
   PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
-  auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
+  auto* input = infer_network_->addInput(name.c_str(), dtype, dims);
   PADDLE_ENFORCE(input, "infer network add input %s failed", name);
-  buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
+  buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
+                        analysis::AccuDims(dims.d, dims.nbDims);
   TensorRTEngine::SetITensor(name, input);
   return input;
 }

@@ -162,13 +165,13 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
 void TensorRTEngine::SetITensor(const std::string& name,
                                 nvinfer1::ITensor* tensor) {
   PADDLE_ENFORCE(tensor != nullptr);
-  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate itensor name %s",
+  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
                     name);
   itensor_map_[name] = tensor;
 }
 
 nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
-  PADDLE_ENFORCE(itensor_map_.count(name), "no itensor %s", name);
+  PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name);
   return itensor_map_[name];
 }
paddle/fluid/inference/tensorrt/helper.h

Lines changed: 0 additions & 9 deletions
@@ -26,15 +26,6 @@ namespace tensorrt {
 
 namespace dy = paddle::platform::dynload;
 
-static size_t AccumDims(nvinfer1::Dims dims) {
-  size_t num = dims.nbDims == 0 ? 0 : 1;
-  for (int i = 0; i < dims.nbDims; i++) {
-    PADDLE_ENFORCE_GT(dims.d[i], 0);
-    num *= dims.d[i];
-  }
-  return num;
-}
-
 // TensorRT data type to size
 const int kDataTypeSize[] = {
     4,  // kFLOAT
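
One behavioral note on this removal: the deleted AccumDims returned 0 for a zero-rank Dims and enforced every extent to be positive, while the analysis::AccuDims that replaces it returns the empty product 1 and performs no validation. A small self-contained comparison; the old helper is adapted here to a raw extent array so the snippet compiles without NvInfer.h, and the PADDLE_ENFORCE_GT check is elided (both adaptations are ours):

#include <cassert>
#include <utility>

// The removed helper's logic over a raw extent array.
static int OldAccumDims(const int* d, int nb_dims) {
  int num = nb_dims == 0 ? 0 : 1;  // zero-rank dims yield 0
  for (int i = 0; i < nb_dims; i++) num *= d[i];
  return num;
}

// The replacement added to analysis/helper.h in this commit.
template <typename Vec>
int AccuDims(Vec&& vec, int size) {
  int res = 1;  // zero-rank dims yield the empty product, 1
  for (int i = 0; i < size; i++) res *= std::forward<Vec>(vec)[i];
  return res;
}

int main() {
  const int* none = nullptr;  // never dereferenced when the rank is 0
  assert(OldAccumDims(none, 0) == 0);
  assert(AccuDims(none, 0) == 1);
  return 0;
}
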
