Skip to content

Commit 1622cb9

Browse files
committed
Fix alpha tensor key
1 parent a8c077d commit 1622cb9

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ class LeakyReluOpConverter : public OpConverter {
7272
nvinfer1::ElementWiseOperation::kSUM);
7373
PADDLE_ENFORCE(nullptr != output_layer);
7474
// keep alpha tensor to avoid release it's memory
75-
engine_->weight_map[op_desc.Input("alpha")[0]] = std::move(alpha_tensor);
75+
std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
76+
PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) ==
77+
engine_->weight_map.end());
78+
engine_->weight_map[alpha_name] = std::move(alpha_tensor);
7679

7780
std::string layer_name = "leaky_relu (Output: ";
7881
auto output_name = op_desc.Output("Out")[0];

paddle/fluid/inference/tensorrt/convert/test_leaky_relu_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(leaky_relu_op, test_channel_wise) {
24-
std::unordered_set<std::string> parameters({"leaky_relu_alpha"});
23+
TEST(leaky_relu_op, test_leaky_relu) {
24+
std::unordered_set<std::string> parameters;
2525
framework::Scope scope;
2626
TRTConvertValidation validator(10, parameters, scope, 1000);
2727
validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2));

0 commit comments

Comments
 (0)