Skip to content

Commit f6fb51a

Browse files
committed
add test_mode in trt/activation_op
1 parent c73977a commit f6fb51a

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

paddle/fluid/inference/tensorrt/convert/activation_op.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ReluOpConverter : public OpConverter {
2323
public:
2424
ReluOpConverter() {}
2525
void operator()(const framework::proto::OpDesc& op,
26-
const framework::Scope& scope) override {
26+
const framework::Scope& scope, bool test_mode) override {
2727
// Here the two nullptr looks strange, that's because the
2828
// framework::OpDesc's constructor is strange.
2929
framework::OpDesc op_desc(op, nullptr);
@@ -34,7 +34,12 @@ class ReluOpConverter : public OpConverter {
3434
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
3535
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
3636
nvinfer1::ActivationType::kRELU);
37-
engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
37+
auto output_name = op_desc.Output("Out")[0];
38+
engine_->SetITensor(output_name, layer->getOutput(0));
39+
if (test_mode) { // the test framework can not determine which is the
40+
// output, so place the declaration inside.
41+
engine_->DeclareOutput(output_name);
42+
}
3843
}
3944
};
4045

0 commit comments

Comments
 (0)