Skip to content

Commit 940f5db

Browse files
committed
modify the tensorrt engine op to adapt to change
1 parent 8252769 commit 940f5db

File tree

3 files changed

+23
-24
lines changed

3 files changed

+23
-24
lines changed

paddle/fluid/operators/tensorrt_engine_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,14 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
5353
PADDLE_ENFORCE_LE(shape.size(), 4UL,
5454
"TensorRT' tensor input requires at most 4 dimensions");
5555

56+
// We should delete the batch size here.
5657
switch (shape.size()) {
5758
case 2:
58-
return nvinfer1::Dims2(shape[0], shape[1]);
59+
return nvinfer1::Dims2(1, shape[1]);
5960
case 3:
60-
return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
61+
return nvinfer1::Dims3(1, shape[1], shape[2]);
6162
case 4:
62-
return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]);
63+
return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]);
6364
default:
6465
return nvinfer1::Dims();
6566
}

paddle/fluid/operators/tensorrt_engine_op.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
9595
PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
9696
auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
9797
fluid_t->Resize(framework::make_ddim(ddim));
98-
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
98+
9999
if (platform::is_cpu_place(fluid_t->place())) {
100100
// TODO(Superjomn) change this float to dtype size.
101101
engine->GetOutputInCPU(
102-
y, fluid_t->mutable_data<float>(platform::CPUPlace()),
103-
size * sizeof(float));
102+
y, fluid_t->mutable_data<float>(platform::CPUPlace()));
104103
} else {
105104
engine->GetOutputInGPU(
106-
y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
107-
size * sizeof(float));
105+
y, fluid_t->mutable_data<float>(platform::CUDAPlace()));
108106
}
109107
}
110108

paddle/fluid/operators/tensorrt_engine_op_test.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,36 +64,37 @@ TEST(TensorRTEngineOp, manual) {
6464

6565
LOG(INFO) << "create block desc";
6666
framework::BlockDesc block_desc(&program, block_);
67-
LOG(INFO) << "create mul op";
68-
auto* mul = block_desc.AppendOp();
69-
mul->SetType("mul");
70-
mul->SetInput("X", std::vector<std::string>({"x"})); // 2 x 4
71-
mul->SetInput("Y", std::vector<std::string>({"y"})); // 4 x 6
72-
mul->SetOutput("Out", std::vector<std::string>({"z"})); // 2 x 6
67+
LOG(INFO) << "create fc op";
68+
auto* fc0 = block_desc.AppendOp();
69+
fc0->SetType("mul");
70+
fc0->SetInput("X", std::vector<std::string>({"x"})); // 4 x 1 x 1
71+
fc0->SetInput("Y", std::vector<std::string>({"y"})); // 4 x 6
72+
fc0->SetOutput("Out", std::vector<std::string>({"z"})); // 6 x 1 x 1
7373

7474
LOG(INFO) << "create fc op";
75-
auto* fc = block_desc.AppendOp();
76-
fc->SetType("mul");
77-
fc->SetInput("X", std::vector<std::string>({"z"}));
78-
fc->SetInput("Y", std::vector<std::string>({"y0"})); // 6 x 8
79-
fc->SetOutput("Out", std::vector<std::string>({"z0"})); // 2 x 8
75+
auto* fc1 = block_desc.AppendOp();
76+
fc1->SetType("mul");
77+
fc1->SetInput("X", std::vector<std::string>({"z"}));
78+
fc1->SetInput("Y", std::vector<std::string>({"y0"})); // 6 x 8
79+
fc1->SetOutput("Out", std::vector<std::string>({"z0"})); // 8 x 1 x 1
8080

8181
// Set inputs' variable shape in BlockDesc
82-
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4}));
82+
// the batch size is 2, so the dims of 'x' is {2, 4, 1, 1}
83+
AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 1, 1}));
8384
AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({4, 6}));
8485
AddTensorToBlockDesc(block_, "y0", std::vector<int64_t>({6, 8}));
8586
AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 6}));
8687

8788
// It is wired, need to copy manually.
88-
*block_->add_ops() = *mul->Proto();
89-
*block_->add_ops() = *fc->Proto();
89+
*block_->add_ops() = *fc0->Proto();
90+
*block_->add_ops() = *fc1->Proto();
9091

9192
ASSERT_EQ(block_->ops_size(), 2);
9293

9394
LOG(INFO) << "create tensorrt desc";
9495
framework::OpDesc engine_op_desc(nullptr);
9596
engine_op_desc.SetType("tensorrt_engine");
96-
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x", "y", "y0"}));
97+
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
9798
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
9899
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
99100
block_->SerializeAsString());
@@ -208,4 +209,3 @@ TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); }
208209
} // namespace paddle
209210

210211
USE_TRT_CONVERTER(mul)
211-
USE_TRT_CONVERTER(fc)

0 commit comments

Comments (0)