Skip to content

Commit 6544cb4

Browse files
authored
Merge pull request #15781 from dzhwinter/test/picked
cherry-pick memory optimize changes to release
2 parents fcdc623 + b2eb623 commit 6544cb4

23 files changed

+1135
-952
lines changed

cmake/flags.cmake

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ function(CheckCompilerCXX11Flag)
2121
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
2222
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
2323
endif()
24-
endif()
24+
endif()
2525
endif()
2626
endfunction()
2727

@@ -147,6 +147,7 @@ set(GPU_COMMON_FLAGS
147147
-Wno-error=unused-function # Warnings in Numpy Header.
148148
-Wno-error=array-bounds # Warnings in Eigen::array
149149
)
150+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
150151
endif(NOT WIN32)
151152

152153
if (APPLE)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -50,12 +50,15 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
5050
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
5151
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
5252

53-
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper)
53+
if(WITH_GPU)
54+
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
55+
else()
56+
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
57+
endif()
58+
5459
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
5560
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
5661
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
57-
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
58-
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
5962
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
6063
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
6164
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
@@ -67,13 +70,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
6770
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
6871
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
6972

70-
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass inplace_op_pass)
73+
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
7174
if (WITH_GPU)
7275
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
7376
endif()
74-
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
75-
cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
76-
77+
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
7778
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
7879

7980
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,8 +206,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
206206
new std::vector<OpDesc *>(main_program.Block(0).AllOps());
207207
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
208208
all_op_descs); // take ownership
209-
graph->Set<GraphNodePool>(kGraphNodePool,
210-
new GraphNodePool); // take ownership
211209

212210
pass->Erase(kAllOpDescs);
213211
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, all_op_descs);
@@ -242,7 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
242240
continue;
243241
}
244242
}
243+
VLOG(3) << "Start Apply Pass " << pass->Type();
245244
graph = pass->Apply(std::move(graph));
245+
VLOG(3) << "Finish Apply Pass " << pass->Type();
246246
}
247247
return graph;
248248
}

paddle/fluid/framework/details/inplace_op_pass.cc

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ DEFINE_bool(
4949
"If this option turns on, only these op in whitelist can be inplaced."
5050
"If it turns off, all of the running op can be candidate of inplaced op."
5151
"Such as scale, elementwise_add"
52-
"By default, it's turned on");
52+
"By default, it's turned off");
5353

5454
DECLARE_string(memory_optimize_debug);
5555

@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
171171
}
172172
}
173173

174-
const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
175-
const std::string& cache_var,
176-
const size_t& idx,
177-
ir::Graph* graph) const {
174+
const NodeSwapQueue InplacePass::TryInplaceModifyVar(
175+
const std::string& var, const std::string& cache_var, const size_t& idx,
176+
ir::Graph* graph) const {
178177
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
179178
var_nodes_[var].at(0)->Var() != nullptr);
180179
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
181180
var_desc->SetName(cache_var);
182181

183-
SSANodePair swap_nodes;
182+
NodeSwapQueue swap_nodes;
184183

185184
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
186185
auto* op = view_.AllOps()[i];
@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
230229
return swap_nodes;
231230
}
232231

233-
void InplacePass::CommitModify(const SSANodePair& swap_nodes,
232+
void InplacePass::CommitModify(const NodeSwapQueue& swap_nodes,
234233
ir::Graph* graph) const {
235234
for (auto& pair : swap_nodes) {
236235
auto *node = pair.first, *cache_node = pair.second;
@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes,
245244
}
246245
}
247246

248-
void InplacePass::WithdrawModify(const SSANodePair& nodes,
247+
void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
249248
ir::Graph* graph) const {
250249
for (auto& pair : nodes) {
251250
auto *node = pair.first, *cache_node = pair.second;
@@ -403,18 +402,20 @@ void GraphView::Build(ir::Graph* g) {
403402
// 2. track the nodes which are used by the parameter server.
404403
// these nodes can not be inplaced, otherwise trainer
405404
// and pserver can not find each other's names.
406-
for (auto& node : g->Nodes()) {
407-
if (!node->IsOp()) continue;
408-
if (node->Name() == "send") {
409-
for (auto& in : node->inputs) {
410-
dup_nodes_.emplace(in->Name());
411-
}
405+
auto update_skip_set = [&](ir::Node* node) {
406+
for (auto& in : node->inputs) {
407+
if (in->IsVar() && in->Var() != nullptr) dup_nodes_.emplace(in->Name());
412408
}
413-
if (node->Name() == "recv") {
414-
for (auto& out : node->outputs) {
409+
for (auto& out : node->outputs) {
410+
if (out->IsVar() && out->Var() != nullptr)
415411
dup_nodes_.emplace(out->Name());
416-
}
417412
}
413+
};
414+
for (auto& node : g->Nodes()) {
415+
if (!node->IsOp()) continue;
416+
if (node->Name() == "send") update_skip_set(node);
417+
if (node->Name() == "recv") update_skip_set(node);
418+
if (node->Name() == "prefetch") update_skip_set(node);
418419
}
419420
}
420421

paddle/fluid/framework/details/inplace_op_pass.h

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,8 @@ class GraphView {
5656
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
5757
};
5858

59-
typedef std::vector<std::pair<ir::Node*, ir::Node*>> SSANodePair;
59+
// swap pairs in sequence
60+
typedef std::vector<std::pair<ir::Node*, ir::Node*>> NodeSwapQueue;
6061
class InplacePass : public ir::Pass {
6162
public:
6263
InplacePass();
@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass {
6869
void InitSSAGraphNodes() const;
6970

7071
private:
71-
const SSANodePair TryInplaceModifyVar(const std::string& var,
72-
const std::string& cache_var,
73-
const size_t& idx,
74-
ir::Graph* graph) const;
72+
const NodeSwapQueue TryInplaceModifyVar(const std::string& var,
73+
const std::string& cache_var,
74+
const size_t& idx,
75+
ir::Graph* graph) const;
7576

76-
void CommitModify(const SSANodePair&, ir::Graph* graph) const;
77+
void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const;
7778

78-
void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const;
79+
void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const;
7980

8081
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
8182
const size_t& idx) const;

0 commit comments

Comments
 (0)