Skip to content

Commit b6c3b69

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-beam-search-size
test=develop
2 parents 5dfce93 + 46a6cac commit b6c3b69

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed: +1037 additions, -89 deletions

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None
325325
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0))
326326
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
327327
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,))
328+
paddle.fluid.layers.box_clip ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,))
328329
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
329330
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
330331
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pass_library(conv_elementwise_add2_act_fuse_pass inference)
6565
pass_library(conv_elementwise_add_fuse_pass inference)
6666
pass_library(conv_affine_channel_fuse_pass inference)
6767
pass_library(transpose_flatten_concat_fuse_pass inference)
68+
pass_library(identity_scale_op_clean_pass base)
6869

6970
# There may be many transpose-flatten structures in a model, and the output of
7071
# these structures will be used as inputs to the concat Op. This pattern will
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/identity_scale_op_clean_pass.h"
16+
#include <string>
17+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
24+
std::unique_ptr<ir::Graph> graph) const {
25+
FusePassBase::Init("identity_scale_op_clean", graph.get());
26+
27+
// pre_op -> scale_in -> scale_op -> scale_out
28+
// ->
29+
// pre_op -> scale_out
30+
GraphPatternDetector detector;
31+
auto pre_op = detector.mutable_pattern()->NewNode("pre_op")->assert_is_op();
32+
auto scale_in = detector.mutable_pattern()
33+
->NewNode("scale_in")
34+
->assert_is_op_input("scale")
35+
->AsIntermediate();
36+
auto scale_op = detector.mutable_pattern()
37+
->NewNode("scale_fuse")
38+
->assert_is_op("scale")
39+
->assert_op_attr<float>("scale", 1.)
40+
->assert_op_attr<float>("bias", 0.);
41+
auto scale_out = detector.mutable_pattern()
42+
->NewNode("scale_out")
43+
->assert_is_op_output("scale");
44+
45+
pre_op->LinksTo({scale_in});
46+
scale_op->LinksFrom({scale_in}).LinksTo({scale_out});
47+
48+
GraphPatternDetector::handle_t handler = [&](
49+
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
50+
Node* scale_op_var = subgraph.at(scale_op);
51+
Node* scale_in_var = subgraph.at(scale_in);
52+
Node* scale_out_var = subgraph.at(scale_out);
53+
Node* pre_op_var = subgraph.at(pre_op);
54+
// Link pre_op directly to scale_out
55+
const std::string scale_in_name = scale_in_var->Name();
56+
const std::string scale_out_name = scale_out_var->Name();
57+
// Remove links in graph
58+
GraphSafeRemoveNodes(graph, {scale_in_var, scale_op_var});
59+
// Modify proto message
60+
auto* pre_op_desc = pre_op_var->Op();
61+
for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) {
62+
auto* arguments = parameter.mutable_arguments();
63+
auto it = std::find(arguments->begin(), arguments->end(), scale_in_name);
64+
PADDLE_ENFORCE(it != arguments->end());
65+
*it = scale_out_name;
66+
}
67+
68+
IR_NODE_LINK_TO(pre_op_var, scale_out_var);
69+
};
70+
71+
detector(graph.get(), handler);
72+
return graph;
73+
}
74+
75+
} // namespace ir
76+
} // namespace framework
77+
} // namespace paddle
78+
79+
REGISTER_PASS(identity_scale_op_clean_pass,
80+
paddle::framework::ir::IdentityScaleOpCleanPass);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
class IdentityScaleOpCleanPass : public FusePassBase {
24+
protected:
25+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
26+
27+
private:
28+
virtual ~IdentityScaleOpCleanPass() = default;
29+
};
30+
31+
} // namespace ir
32+
} // namespace framework
33+
} // namespace paddle

paddle/fluid/framework/scope.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@ limitations under the License. */
2222
#include "paddle/fluid/framework/threadpool.h"
2323
#include "paddle/fluid/string/printf.h"
2424

25-
DEFINE_bool(benchmark, false,
26-
"Doing memory benchmark. It will make deleting scope synchronized, "
27-
"and add some memory usage logs."
28-
"Default cuda is asynchronous device, set to True will"
29-
"force op run in synchronous mode.");
25+
DECLARE_bool(benchmark);
3026

3127
DEFINE_bool(
3228
eager_delete_scope, true,

paddle/fluid/inference/analysis/ir_pass_manager.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ void IRPassManager::CreatePasses(Argument *argument,
8383
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
8484
}
8585

86-
// graph_ = pass->Apply(std::move(graph_));
8786
pre_pass = pass_name;
8887

8988
passes_.emplace_back(std::move(pass));
@@ -97,8 +96,9 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
9796
PADDLE_ENFORCE(graph.get());
9897
// Apply all the passes
9998
for (const auto &pass : passes_) {
100-
if (pass->Type() == "graph_viz_pass") continue;
101-
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
99+
if (pass->Type() != "graph_viz_pass") {
100+
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
101+
}
102102
graph = pass->Apply(std::move(graph));
103103
}
104104
return std::move(graph);

paddle/fluid/inference/api/analysis_config.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,4 +318,9 @@ NativeConfig AnalysisConfig::ToNativeConfig() const {
318318
return config;
319319
}
320320

321+
void AnalysisConfig::SwitchIrDebug(int x) {
322+
ir_debug_ = x;
323+
Update();
324+
}
325+
321326
} // namespace paddle

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ namespace {
5858
bool IsPersistable(const framework::VarDesc *var) {
5959
if (var->Persistable() &&
6060
var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
61-
var->GetType() != framework::proto::VarType::FETCH_LIST) {
61+
var->GetType() != framework::proto::VarType::FETCH_LIST &&
62+
var->GetType() != framework::proto::VarType::RAW) {
6263
return true;
6364
}
6465
return false;

paddle/fluid/inference/api/analysis_predictor_tester.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ TEST(AnalysisPredictor, memory_optim) {
196196
AnalysisConfig config(FLAGS_dirname);
197197
config.DisableGpu();
198198
config.EnableMemoryOptim(true);
199-
config.pass_builder()->TurnOnDebug();
199+
config.SwitchIrDebug();
200200

201201
auto native_predictor =
202202
CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());

paddle/fluid/inference/api/paddle_analysis_config.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,12 @@ struct AnalysisConfig {
140140
*/
141141
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
142142

143-
/** Control whther to debug IR graph analysis phase.
143+
/** \brief Control whether to debug IR graph analysis phase.
144+
*
145+
* This will generate DOT files for visualizing the computation graph after
146+
* each analysis pass applied.
144147
*/
145-
void SwitchIrDebug(int x = true) { ir_debug_ = x; }
148+
void SwitchIrDebug(int x = true);
146149

147150
/** Turn on MKLDNN.
148151
*/

0 commit comments

Comments (0)