Skip to content

Commit 674bd83

Browse files
authored
OpConverter change BlockDesc to proto::BlockDesc (#10623)
1 parent de81ccb commit 674bd83

File tree

7 files changed

+21
-16
lines changed

7 files changed

+21
-16
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
1+
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc op_converter.h DEPS ${FLUID_CORE_MODULES})
22
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
33
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
44
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)

paddle/fluid/inference/tensorrt/convert/activation_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@ namespace tensorrt {
2121
class ReluOpConverter : public OpConverter {
2222
public:
2323
ReluOpConverter() {}
24-
void operator()(const framework::OpDesc& op) override {
24+
void operator()(const framework::proto::OpDesc& op) override {
25+
// Here the two nullptr looks strange, that's because the
26+
// framework::OpDesc's constructor is strange.
27+
framework::OpDesc op_desc(op, nullptr, nullptr);
2528
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
2629
"type is Relu";
2730
const nvinfer1::ITensor* input_tensor =
28-
engine_->GetITensor(op.Input("X")[0]);
31+
engine_->GetITensor(op_desc.Input("X")[0]);
2932
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
3033
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
3134
nvinfer1::ActivationType::kRELU);
32-
engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0));
35+
engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0));
3336
}
3437
};
3538

paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace tensorrt {
2121
class Conv2dOpConverter : public OpConverter {
2222
public:
2323
Conv2dOpConverter() {}
24-
void operator()(const framework::OpDesc& op) override {
24+
void operator()(const framework::proto::OpDesc& op) override {
2525
LOG(INFO)
2626
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
2727
}

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace tensorrt {
2121
class MulOpConverter : public OpConverter {
2222
public:
2323
MulOpConverter() {}
24-
void operator()(const framework::OpDesc& op) override {
24+
void operator()(const framework::proto::OpDesc& op) override {
2525
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
2626
}
2727
};

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,27 @@ namespace tensorrt {
3131
class OpConverter {
3232
public:
3333
OpConverter() {}
34-
virtual void operator()(const framework::OpDesc& op) {}
34+
virtual void operator()(const framework::proto::OpDesc& op) {}
3535

36-
void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
37-
std::string type = op.Type();
36+
void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
37+
std::string type = op.type();
3838
auto* it = Registry<OpConverter>::Lookup(type);
3939
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
4040
it->SetEngine(engine);
4141
(*it)(op);
4242
}
4343

4444
// convert fluid op to tensorrt layer
45-
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
45+
void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
4646
OpConverter::Run(op, engine);
4747
}
4848

4949
// convert fluid block to tensorrt network
50-
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
51-
for (auto op : block.AllOps()) {
52-
OpConverter::Run(*op, engine);
50+
void ConvertBlock(const framework::proto::BlockDesc& block,
51+
TensorRTEngine* engine) {
52+
for (size_t i = 0; i < block.ops_size(); i++) {
53+
const auto& op = block.ops(i);
54+
OpConverter::Run(op, engine);
5355
}
5456
}
5557

paddle/fluid/inference/tensorrt/convert/test_activation_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void Compare(float input, float expect) {
4949
op_desc.SetInput("X", {"X"});
5050
op_desc.SetOutput("Out", {"Out"});
5151

52-
auto relu_op = framework::OpRegistry::CreateOp(op_desc);
52+
auto relu_op = framework::OpRegistry::CreateOp(*op_desc.Proto());
5353

5454
// run fluid op
5555
relu_op->Run(scope, place);
@@ -65,7 +65,7 @@ void Compare(float input, float expect) {
6565
nvinfer1::DimsCHW{1, 1, 1});
6666

6767
OpConverter op_converter;
68-
op_converter.ConvertOp(op_desc, engine);
68+
op_converter.ConvertOp(*op_desc.Proto(), engine);
6969

7070
engine->DeclareOutput("Out");
7171
engine->FreezeNetwork();

paddle/fluid/inference/tensorrt/convert/test_op_converter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) {
2929
conv2d_op->SetType("conv2d");
3030

3131
OpConverter converter;
32-
converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/);
32+
converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
3333
}
3434

3535
} // namespace tensorrt

0 commit comments

Comments
 (0)