
Commit 5ab96d3

add warning and skip vars to mem opt passes (#16967)
test=release/1.4
1 parent 64c1427 commit 5ab96d3

14 files changed (+202, -35 lines)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,8 @@ cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_
 cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
 cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)

+cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
+
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

 if(WITH_DISTRIBUTE)
@@ -114,4 +116,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
         fuse_relu_depthwise_conv_pass
         memory_optimize_pass lock_free_optimize_pass
         alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
-        fuse_adam_op_pass fuse_sgd_op_pass)
+        fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass)
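This build rule is one of three hookups the commit makes so that the new pass can be appended by name; the other two are in C++ and appear in the diffs below. For orientation, the full chain is:

    // 1) Define and register the pass in its own .cc file
    //    (see record_skip_memory_opt_vars_pass.cc below):
    REGISTER_PASS(record_skip_memory_opt_vars_pass,
                  paddle::framework::details::RecordSkipMemoryOptVarsPass);
    // 2) Reference it from build_strategy.cc so the registrar is linked in:
    USE_PASS(record_skip_memory_opt_vars_pass);
    // 3) Append it by name in the pass builder:
    AppendPass("record_skip_memory_opt_vars_pass");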

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
   }

+  // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
+  AppendPass("record_skip_memory_opt_vars_pass");
+
   if (strategy_.enable_sequential_execution_) {
     VLOG(10) << "Add sequential_execution_pass";
     AppendPass("sequential_execution_pass");
@@ -320,3 +323,4 @@ USE_PASS(graph_to_program_pass);
 USE_PASS(fuse_adam_op_pass);
 USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
+USE_PASS(record_skip_memory_opt_vars_pass);
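The ordering constraint exists because the graph attribute is created exactly once: record_skip_memory_opt_vars_pass Sets kMemOptSkipVars, and every later memory-optimize pass Gets it after enforcing its presence, as the diffs below show. A minimal sketch of that producer/consumer contract (the pass names here are hypothetical; only the attribute calls mirror this commit):

    #include "paddle/fluid/framework/details/memory_optimize_helper.h"
    #include "paddle/fluid/framework/ir/graph.h"
    #include "paddle/fluid/framework/ir/pass.h"

    namespace paddle { namespace framework { namespace details {

    // Hypothetical producer: must run first, creates the shared skip set.
    class ProducerSketchPass : public ir::Pass {
     protected:
      void ApplyImpl(ir::Graph* graph) const override {
        graph->Set(kMemOptSkipVars, new MemOptSkipVars);  // graph owns the set
      }
    };

    // Hypothetical consumer: any later mem-opt pass. Get() on a missing
    // attribute fails, hence the enforce-then-get pattern in this commit.
    class ConsumerSketchPass : public ir::Pass {
     protected:
      void ApplyImpl(ir::Graph* graph) const override {
        PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
        auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
        // ...skip any candidate variable whose name is in skip_vars...
      }
    };

    }}}  // namespace paddle::framework::details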

paddle/fluid/framework/details/inplace_op_pass.cc

Lines changed: 27 additions & 19 deletions
@@ -303,7 +303,16 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
     auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
     auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);

-    VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
+    VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name;
+    if (view_.InSkipSet(in_var_name)) {
+      VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name);
+      continue;
+    }
+
+    if (view_.InSkipSet(out_var_name)) {
+      VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name);
+      continue;
+    }

     if (var_nodes_[in_var_name].back() != in_node) {
       VLOG(4) << "SKIP since " << in_var_name
@@ -318,11 +327,15 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
               << out_var_name << " are the same";
     } else if (!NodeCanReused(in_node)) {
       can_replace = false;
-      VLOG(4) << "SKIP: Input varialbe " << in_var_name << "cannot be reused";
+      VLOG(4) << "SKIP: Input variable " << in_var_name << " cannot be reused";
     } else if (!NodeCanReused(out_node)) {
       can_replace = false;
       VLOG(4) << "SKIP: Output variable " << out_var_name
               << " cannot be reused";
+    } else if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
+      can_replace = false;
+      VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType()
+              << " does not match Output type : " << out_node->Var()->GetType();
     } else if (details::NodeSize(*in_node->Var()) !=
                details::NodeSize(*out_node->Var())) {
       can_replace = false;
@@ -331,8 +344,8 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,

     if (!can_replace) continue;

-    // 2. there is no external pending op on the input node
-    // if (view_.PendingOpsOnVar(in_node).size() > 1) {
+    // 2. If the variable is the input of multiple ops, we need to make sure
+    //    the current op has a dependency on the other ops that use it
     if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
       VLOG(4) << string::Sprintf(
           "Skiped pair %s => %s. %s input has external dependency."
@@ -341,17 +354,6 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
       continue;
     }

-    // 3. if output has been memory optimize by python(fluid.memory_optmize()).
-    // this candidate can not be inplaced. Will be deprecated in the future.
-    if (view_.InSkipSet(out_node->Name())) {
-      VLOG(4) << string::Sprintf(
-          "Skiped %s => %s reused previous memory block in python memory "
-          "optmize,"
-          "it inplace may generate a circle",
-          out_var_name, in_var_name, op->Name());
-      continue;
-    }
-
     // Debug Interface. Which would be skipped by the pass.
     if (out_node->Name() == FLAGS_memory_optimize_debug) {
       VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
@@ -519,16 +521,22 @@ void GraphView::Build(ir::Graph* g) {
   // resolve data harzards depends on the var nodes in right order.
   TopoSort(g);

+  // fill the skip_set_
+  PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars));
+  auto& mem_opt_whitelist = g->Get<MemOptSkipVars>(kMemOptSkipVars);
+  for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
+
   // 2. track the nodes which used by parameter server.
   // these node can not be inplaced, otherwise trainer
   // pserver can not find each other name.
   auto update_skip_set = [&](ir::Node* node) {
     for (auto& in : node->inputs) {
-      if (in->IsVar() && in->Var() != nullptr) dup_nodes_.emplace(in->Name());
+      if (in->IsVar() && in->Var() != nullptr) {
+        skip_set_.emplace(in->Name());
+      }
     }
     for (auto& out : node->outputs) {
-      if (out->IsVar() && out->Var() != nullptr)
-        dup_nodes_.emplace(out->Name());
+      if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name());
     }
   };
   for (auto& node : g->Nodes()) {
@@ -545,7 +553,7 @@ void GraphView::Build(ir::Graph* g) {
 const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }

 bool GraphView::InSkipSet(const std::string& var) const {
-  return dup_nodes_.count(var);
+  return skip_set_.count(var);
 }

 }  // namespace details
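In sum, a candidate pair now survives only if neither name is in the skip set, both nodes are reusable, and the two variables agree in both type and byte size. A condensed sketch of that filter as a standalone predicate (a hypothetical helper, not part of the commit; NodeCanReused and NodeSize are the helpers declared in memory_optimize_helper.h):

    // Hypothetical summary of the checks in TryInplaceOpInputOutput above.
    static bool CandidateSurvives(const GraphView& view, ir::Node* in_node,
                                  ir::Node* out_node) {
      if (view.InSkipSet(in_node->Name()) || view.InSkipSet(out_node->Name()))
        return false;  // new: protected variables are never inplaced
      if (!NodeCanReused(in_node) || !NodeCanReused(out_node))
        return false;  // e.g. persistable or feed/fetch vars cannot be reused
      if (in_node->Var()->GetType() != out_node->Var()->GetType())
        return false;  // new: LOD_TENSOR must not alias SELECTED_ROWS, etc.
      return details::NodeSize(*in_node->Var()) ==
             details::NodeSize(*out_node->Var());  // byte sizes must match
    }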

paddle/fluid/framework/details/inplace_op_pass.h

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class GraphView {

  private:
   std::vector<ir::Node*> ops_;
-  std::unordered_set<std::string> dup_nodes_;  // mem opt affect nodes
+  std::unordered_set<std::string> skip_set_;   // mem opt affect nodes
   std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
   std::unordered_map<ir::Node*, uint32_t> op_level_;
 };

paddle/fluid/framework/details/memory_optimize_helper.h

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,7 @@
 #include <set>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
@@ -30,6 +31,11 @@ namespace paddle {
 namespace framework {
 namespace details {

+/// this attribute is used to prevent some core variables from being
+/// removed/reused in the memory-optimize-related passes
+constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
+typedef std::unordered_set<std::string> MemOptSkipVars;
+
 std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);

 // NOTE(dzh): A ordered set for node reuse in memory optimize.
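The constant is the attribute key and the typedef fixes the container type, so producer and consumer passes agree on both. A minimal usage sketch, mirroring the calls in the diffs above (the variable name is illustrative only):

    // Producer side: create the set on the graph, then record a name in it.
    graph->Set(kMemOptSkipVars, new MemOptSkipVars);
    graph->Get<MemOptSkipVars>(kMemOptSkipVars).insert("fc_0.w_0@GRAD");

    // Consumer side: check before renaming or reusing any variable.
    PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
    const auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
    if (skip_vars.count(var_name)) {
      // leave this variable alone: no rename, no reuse, no early free
    }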

paddle/fluid/framework/details/memory_optimize_pass.cc

Lines changed: 9 additions & 4 deletions
@@ -45,8 +45,7 @@ namespace framework {
 namespace details {

 void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
-  auto nodes = graph->Nodes();
-  CollectSkipVarsSet(nodes);
+  CollectSkipVarsSet(graph);

   cfg_.reset(new details::ControlFlowGraph(*graph));
   cfg_->LiveVariableAnalysis();
@@ -204,14 +203,20 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
   }
 }

-void MemoryOptimizePass::CollectSkipVarsSet(
-    const std::unordered_set<ir::Node*>& nodes) const {
+void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
+  // fill skip_set_
+  PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
+  auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
+  for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
+
   auto update_skip_set = [&](OpDesc* op_desc) {
     auto inputs = op_desc->InputArgumentNames();
     auto outputs = op_desc->OutputArgumentNames();
     skip_set_.insert(inputs.begin(), inputs.end());
     skip_set_.insert(outputs.begin(), outputs.end());
   };
+
+  auto nodes = graph->Nodes();
   for (auto& op : nodes) {
     if (!op->IsOp() || op->Op() == nullptr) continue;
     auto* op_desc = op->Op();
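After this change, skip_set_ is the union of three sources, matching the updated comment in the header below. An illustrative trace (all variable names are made up):

    // 1. Graph-level whitelist recorded by record_skip_memory_opt_vars_pass:
    //        {"fc_0.w_0@GRAD"}
    // 2. Inputs/outputs of sub-block ops (while, while_grad, conditional_block):
    //        {"while_x", "while_out"}
    // 3. Inputs/outputs of distributed ops:
    //        {"send_var"}
    // After CollectSkipVarsSet(graph):
    //        skip_set_ == {"fc_0.w_0@GRAD", "while_x", "while_out", "send_var"}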

paddle/fluid/framework/details/memory_optimize_pass.h

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ class MemoryOptimizePass : public ir::Pass {
   // 1. scan op with subblock and collect the output/input vars.
   //    while, while_grad, conditional_block
   // 2. scan distributed ops and collect the output/input vars
-  void CollectSkipVarsSet(const std::unordered_set<ir::Node*>&) const;
+  // 3. op_role_vars
+  void CollectSkipVarsSet(ir::Graph* graph) const;

  private:
   // Reuse Node Pool, Owned.
paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class RecordSkipMemoryOptVarsPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override {
+    PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
+    graph->Set(kMemOptSkipVars, new MemOptSkipVars);
+    auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
+
+    // NOTE(zcd): Insert OpRoleVars into the skip set to prevent these vars
+    // from being renamed by the memory-optimize passes.
+    InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
+  }
+
+  void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
+                                    MemOptSkipVars* skip_vars) const {
+    for (auto& node : graph->Nodes()) {
+      PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
+      if (node->IsOp() && node->Op()) {
+        try {
+          auto op_role_vars =
+              boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                  OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+          PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
+          for (size_t i = 0; i < op_role_vars.size(); i += 2) {
+            auto& g_name = op_role_vars[i + 1];
+            skip_vars->insert(g_name);
+          }
+        } catch (boost::bad_get e) {
+        }
+      }
+    }
+  }
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(record_skip_memory_opt_vars_pass,
+              paddle::framework::details::RecordSkipMemoryOptVarsPass);
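The extraction loop relies on the flattened layout of the op-role-var attribute: parameter and gradient names interleave as {param0, grad0, param1, grad1, ...}, which is why the size must be even and why index i + 1 always holds a gradient. A worked example with made-up names:

    // Illustrative attribute value on a backward op:
    std::vector<std::string> op_role_vars = {"fc_0.w_0", "fc_0.w_0@GRAD",
                                             "fc_0.b_0", "fc_0.b_0@GRAD"};
    MemOptSkipVars skip_vars;
    for (size_t i = 0; i < op_role_vars.size(); i += 2) {
      // op_role_vars[i] is a parameter; op_role_vars[i + 1] is its gradient.
      skip_vars.insert(op_role_vars[i + 1]);
    }
    // skip_vars == {"fc_0.w_0@GRAD", "fc_0.b_0@GRAD"}; these gradients are now
    // protected from rename/reuse by the memory-optimize passes.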

paddle/fluid/framework/inplace_op_inference_test.cc

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include <vector>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/details/inplace_op_pass.h"
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -217,6 +218,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {

   FakeSuccData(&prog);
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   g = test_SingleOpInplaceInToOut(std::move(g));
   auto op_node = GetNodeFromGraph(g.get(), "single_op");

@@ -232,6 +234,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {

   FakeNoInplaceData(&prog);
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   g = test_SingleOpInplaceInToOut(std::move(g));
   auto op_node = GetNodeFromGraph(g.get(), "single_op");

@@ -264,6 +267,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
   prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});

   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
   pass->Apply(g.get());
   auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
@@ -299,6 +303,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
   prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});

   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
   std::unique_ptr<details::InplacePass> pass(new details::InplacePass());
   pass->Apply(g.get());
   auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");

paddle/fluid/pybind/const_value.cc

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/pybind/const_value.h"
+#include "paddle/fluid/framework/details/memory_optimize_pass.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
@@ -28,6 +29,7 @@ void BindConstValue(pybind11::module* m) {
   m->def("kControlDepVarName",
          [] { return framework::ir::Node::kControlDepVarName; });
   m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
+  m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; });

   auto op_proto_and_checker_maker =
       m->def_submodule("op_proto_and_checker_maker");
