Skip to content

Commit 897789b

Browse files
authored
fix save_inference_model bug (#15365)
1 parent ba02ac4 commit 897789b

File tree

6 files changed

+131
-2
lines changed

6 files changed

+131
-2
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pass_library(conv_elementwise_add2_act_fuse_pass inference)
6565
pass_library(conv_elementwise_add_fuse_pass inference)
6666
pass_library(conv_affine_channel_fuse_pass inference)
6767
pass_library(transpose_flatten_concat_fuse_pass inference)
68+
pass_library(identity_scale_op_clean_pass base)
6869

6970
# There may be many transpose-flatten structures in a model, and the output of
7071
# these structures will be used as inputs to the concat Op. This pattern will
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/identity_scale_op_clean_pass.h"
#include <algorithm>  // std::find — was missing; only <string> was included
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Removes "identity" scale ops (scale == 1, bias == 0), which are no-ops
// inserted e.g. by save_inference_model, by rewiring the producer directly
// to the scale op's output.
std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init("identity_scale_op_clean", graph.get());

  // pre_op -> scale_in -> scale_op -> scale_out
  // ->
  // pre_op -> scale_out
  GraphPatternDetector detector;
  auto pre_op = detector.mutable_pattern()->NewNode("pre_op")->assert_is_op();
  auto scale_in = detector.mutable_pattern()
                      ->NewNode("scale_in")
                      ->assert_is_op_input("scale")
                      ->AsIntermediate();
  // Only match scale ops that are exact identities: scale=1, bias=0.
  auto scale_op = detector.mutable_pattern()
                      ->NewNode("scale_fuse")
                      ->assert_is_op("scale")
                      ->assert_op_attr<float>("scale", 1.)
                      ->assert_op_attr<float>("bias", 0.);
  auto scale_out = detector.mutable_pattern()
                       ->NewNode("scale_out")
                       ->assert_is_op_output("scale");

  pre_op->LinksTo({scale_in});
  scale_op->LinksFrom({scale_in}).LinksTo({scale_out});

  GraphPatternDetector::handle_t handler = [&](
      const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
    Node* scale_op_var = subgraph.at(scale_op);
    Node* scale_in_var = subgraph.at(scale_in);
    Node* scale_out_var = subgraph.at(scale_out);
    Node* pre_op_var = subgraph.at(pre_op);
    // Guard: scale_in must feed ONLY the matched scale op. Otherwise deleting
    // it below would orphan its other consumers and corrupt the graph.
    if (scale_in_var->outputs.size() != 1UL) return;
    // Link pre_op directly to scale_out.
    const std::string scale_in_name = scale_in_var->Name();
    const std::string scale_out_name = scale_out_var->Name();
    // Remove the intermediate var and the identity scale op from the graph.
    GraphSafeRemoveNodes(graph, {scale_in_var, scale_op_var});
    // Patch pre_op's proto so its output argument scale_in becomes scale_out.
    auto* pre_op_desc = pre_op_var->Op();
    for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) {
      auto* arguments = parameter.mutable_arguments();
      auto it = std::find(arguments->begin(), arguments->end(), scale_in_name);
      PADDLE_ENFORCE(it != arguments->end(),
                     "pre_op must produce the matched scale input");
      *it = scale_out_name;
    }

    IR_NODE_LINK_TO(pre_op_var, scale_out_var);
  };

  detector(graph.get(), handler);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(identity_scale_op_clean_pass,
              paddle::framework::ir::IdentityScaleOpCleanPass);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include <memory>  // std::unique_ptr in ApplyImpl's signature — was missing

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

// Graph pass that deletes identity scale ops (scale == 1, bias == 0),
// rewiring each producer directly to the scale op's output.
class IdentityScaleOpCleanPass : public FusePassBase {
 protected:
  // Rewrites pre_op -> scale_in -> scale -> scale_out into
  // pre_op -> scale_out. `override` added: the base declares this virtual.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

 private:
  virtual ~IdentityScaleOpCleanPass() = default;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

paddle/fluid/inference/api/paddle_pass_builder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class CpuPassStrategy : public PassStrategy {
117117
"conv_bn_fuse_pass", //
118118
"conv_eltwiseadd_bn_fuse_pass", //
119119
"is_test_pass", //
120+
"identity_scale_op_clean_pass", //
120121
});
121122
use_gpu_ = false;
122123
}
@@ -155,6 +156,7 @@ class GpuPassStrategy : public PassStrategy {
155156
GpuPassStrategy() : PassStrategy({}) {
156157
passes_.assign({
157158
"infer_clean_graph_pass", //
159+
"identity_scale_op_clean_pass", //
158160
"conv_affine_channel_fuse_pass", //
159161
"conv_eltwiseadd_affine_channel_fuse_pass", //
160162
"conv_bn_fuse_pass", //

python/paddle/fluid/io.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import six
2222
from functools import reduce
2323

24+
from paddle.fluid import layers
2425
from paddle.fluid.executor import Executor
2526
from paddle.fluid.evaluator import Evaluator
26-
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
27+
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard
2728
from . import core
2829

2930
__all__ = [
@@ -931,6 +932,17 @@ def save_inference_model(dirname,
931932
if main_program is None:
932933
main_program = default_main_program()
933934

935+
# fix the bug that the activation op's output as target will be pruned.
936+
# will affect the inference performance.
937+
# TODO(Superjomn) add an IR pass to remove 1-scale op.
938+
with program_guard(main_program):
939+
uniq_target_vars = []
940+
for var in target_vars:
941+
if isinstance(var, Variable):
942+
var1 = layers.scale(var, 1.)
943+
uniq_target_vars.append(var1)
944+
target_vars = uniq_target_vars
945+
934946
# when a pserver and a trainer running on the same machine, mkdir may conflict
935947
try:
936948
os.makedirs(dirname)

python/paddle/fluid/tests/unittests/test_inference_model_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def test_fit_line_inference_model(self):
8282

8383
self.assertEqual(feed_var_names, ["x", "y"])
8484
self.assertEqual(len(fetch_vars), 1)
85-
self.assertEqual(str(fetch_vars[0]), str(avg_cost))
85+
print("fetch %s" % str(fetch_vars[0]))
86+
self.assertTrue("scale" in str(fetch_vars[0]))
8687
self.assertEqual(expected, actual)
8788

8889

0 commit comments

Comments
 (0)