Skip to content

Commit 74bc55c

Browse files
authored
Merge pull request #14975 from dzhwinter/ir_inplace_pass
Ir inplace pass
2 parents 546eefa + 9f001c6 commit 74bc55c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1647
-304
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ cc_test(version_test SRCS version_test.cc DEPS version)
128128

129129
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
130130

131-
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
131+
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
132132
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
133133

134134
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
@@ -192,6 +192,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
192192
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
193193
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
194194
proto_desc)
195+
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper)
195196
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
196197
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
197198

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
5050
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
5151
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
5252

53-
cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass)
53+
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper)
54+
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
55+
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
5456
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
5557
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
5658
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
@@ -65,12 +67,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
6567
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
6668
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
6769

68-
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass)
70+
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass inplace_op_pass)
6971
if (WITH_GPU)
7072
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
7173
endif()
72-
cc_test(memory_reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types.cc DEPS framework_proto graph)
73-
cc_test(analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types.cc DEPS framework_proto graph graph_helper op_registry pass)
74+
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
75+
cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
7476

7577
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
7678

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License. */
1717
#include <glog/logging.h>
1818
#include <memory>
1919

20-
#include "paddle/fluid/framework/details/memory_reuse_types.h"
20+
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
2121
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
2222
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
2323
#include "paddle/fluid/framework/details/reduce_op_handle.h"
@@ -47,6 +47,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
4747
AppendPass("sequential_execution_pass");
4848
}
4949

50+
// Add op fusion.
51+
if (strategy.fuse_relu_depthwise_conv_) {
52+
AppendPass("fuse_relu_depthwise_conv_pass");
53+
}
54+
55+
// NOTE(dzhwinter): A note for automatical inplace.
56+
// 1. modify program desc passes should put
57+
// before inplace pass.
58+
// 2. manually configured inplace should put
59+
// before inplace_pass
60+
61+
// Add automatically inplace.
62+
if (strategy_.enable_inplace_) {
63+
AppendPass("inplace_pass");
64+
}
65+
5066
// Add a graph viz pass to record a graph.
5167
if (!strategy_.debug_graphviz_path_.empty()) {
5268
auto viz_pass = AppendPass("graph_viz_pass");
@@ -55,10 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
5571
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
5672
}
5773

58-
// Add op fusion.
59-
if (strategy.fuse_relu_depthwise_conv_) {
60-
AppendPass("fuse_relu_depthwise_conv_pass");
61-
}
6274
if (strategy.fuse_elewise_add_act_ops_) {
6375
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
6476
// Add a graph viz pass to record a graph.
@@ -88,7 +100,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
88100
// A side-effect of that, memory optimize cannot forsee the fetched vars
89101
// , so fetchlist should be set persistable before call the Run interface.
90102
if (strategy.memory_optimize_) {
91-
auto analysis_var_pass = AppendPass("analysis_var_pass");
103+
auto memory_optimize_pass = AppendPass("memory_optimize_pass");
92104
}
93105

94106
AppendMultiDevPass(strategy);
@@ -186,8 +198,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
186198
pass->Erase("nccl_ctxs");
187199
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
188200
#endif
189-
190-
} else if (pass->Type() == "analysis_var_pass") {
201+
} else if (pass->Type() == "memory_optimize_pass") {
202+
if (graph->Has(kAllOpDescs)) {
203+
graph->Erase(kAllOpDescs);
204+
}
191205
const std::vector<OpDesc *> *all_op_descs =
192206
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
193207
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
@@ -214,6 +228,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
214228
pass->Set<const std::vector<OpDesc *>>(
215229
kAllOpDescs,
216230
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
231+
} else if (pass->Type() == "inplace_pass") {
232+
if (graph->Has(kAllOpDescs)) {
233+
graph->Erase(kAllOpDescs);
234+
}
235+
graph->Set<const std::vector<OpDesc *>>(
236+
kAllOpDescs,
237+
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
217238
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
218239
if (!use_cuda) {
219240
LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
@@ -239,9 +260,10 @@ USE_PASS(allreduce_mode_multi_devices_pass);
239260
USE_PASS(dist_multi_devices_pass);
240261
USE_PASS(multi_devices_check_pass);
241262
USE_PASS(multi_devices_print_pass);
242-
USE_PASS(analysis_var_pass);
263+
USE_PASS(memory_optimize_pass);
243264
USE_PASS(sequential_execution_pass);
244265
USE_PASS(all_reduce_deps_pass);
245266
USE_PASS(modify_op_lock_and_record_event_pass);
267+
USE_PASS(inplace_pass);
246268
USE_PASS(lock_free_optimize_pass);
247269
USE_PASS(graph_to_program_pass);

paddle/fluid/framework/details/build_strategy.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ struct BuildStrategy {
8080

8181
bool memory_early_delete_{false};
8282

83+
// TODO(dzhwinter):
84+
// make enable_inplace, memory_optimize_
85+
// memory_early_delete_ true by default
86+
bool enable_inplace_{false};
87+
8388
bool enable_sequential_execution_{false};
8489

8590
bool fuse_broadcast_op_{false};
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <algorithm>
18+
#include <iostream>
19+
#include <iterator>
20+
#include <string>
21+
#include "glog/logging.h"
22+
#include "gtest/gtest.h"
23+
#include "paddle/fluid/framework/ir/graph.h"
24+
#include "paddle/fluid/framework/ir/graph_helper.h"
25+
#include "paddle/fluid/framework/op_registry.h"
26+
#include "paddle/fluid/framework/program_desc.h"
27+
28+
namespace paddle {
29+
namespace framework {
30+
31+
class DummyOp : public OperatorBase {
32+
public:
33+
DummyOp(const std::string& type, const VariableNameMap& inputs,
34+
const VariableNameMap& outputs, const AttributeMap& attrs)
35+
: OperatorBase(type, inputs, outputs, attrs) {}
36+
37+
private:
38+
void RunImpl(const Scope& scope,
39+
const platform::Place& place) const override {}
40+
};
41+
42+
class SumOpMaker : public OpProtoAndCheckerMaker {
43+
public:
44+
void Make() {
45+
AddInput("X", "").AsDuplicable();
46+
AddOutput("Out", "");
47+
AddComment("");
48+
}
49+
};
50+
51+
class AssignOpMaker : public OpProtoAndCheckerMaker {
52+
public:
53+
void Make() {
54+
AddInput("X", "").AsDuplicable();
55+
AddOutput("Out", "");
56+
AddComment("");
57+
}
58+
};
59+
60+
class SplitOpMaker : public OpProtoAndCheckerMaker {
61+
public:
62+
void Make() {
63+
AddInput("X", "");
64+
AddOutput("Out", "").AsDuplicable();
65+
AddComment("");
66+
}
67+
};
68+
69+
class DummyVarTypeInference : public VarTypeInference {
70+
public:
71+
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
72+
auto& inputs = op_desc.Input("X");
73+
auto type = block->Var(inputs.front())->GetType();
74+
auto out_var_name = op_desc.Output("Out").front();
75+
block->Var(out_var_name)->SetType(type);
76+
}
77+
};
78+
79+
} // namespace framework
80+
} // namespace paddle

0 commit comments

Comments
 (0)