Skip to content

Commit 6b45c5a

Browse files
authored
Merge pull request #12605 from panyx0718/ir
code clean up and renaming
2 parents 24283d9 + 626abfc commit 6b45c5a

15 files changed

+152
-180
lines changed

doc/fluid/design/ir/draft.md renamed to doc/fluid/design/ir/overview.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
177177
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
178178
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
179179
mem_opt_pass->Apply(std::move(graph));
180-
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
181-
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
180+
graph = PassRegistry::Instance().Get("multi_devices_pass").Apply(std::move(grah));
181+
graph = PassRegistry::Instance().Get("multi_devices_check_pass").Apply(std::move(grah));
182182
Executor exe;
183183
exe.Run(graph);
184184

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ else()
100100
endif()
101101

102102

103-
cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
103+
cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass)
104104

105105
cc_library(prune SRCS prune.cc DEPS framework_proto)
106106
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
55
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
66
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
77

8-
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph graph_helper)
9-
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
10-
cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
8+
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
9+
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
10+
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
1111

1212
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
1313

@@ -28,7 +28,7 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
2828
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
2929
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
3030

31-
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
31+
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
3232
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
3333

3434
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)

paddle/fluid/framework/details/ssa_graph_checker.cc renamed to paddle/fluid/framework/details/multi_devices_graph_check_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
15+
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
1616
#include <string>
1717
#include "paddle/fluid/framework/ir/graph.h"
1818

@@ -86,7 +86,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
8686
} // namespace framework
8787
} // namespace paddle
8888

89-
REGISTER_PASS(multi_device_check_pass,
89+
REGISTER_PASS(multi_devices_check_pass,
9090
paddle::framework::details::SSAGraghBuilderWithChecker)
9191
.RequireGraphAttr(paddle::framework::details::kGraphVars)
9292
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)

paddle/fluid/framework/details/ssa_graph_checker.h renamed to paddle/fluid/framework/details/multi_devices_graph_check_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
#pragma once
1616

17-
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
17+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
1818

1919
#include <string>
2020

2121
namespace paddle {
2222
namespace framework {
2323
namespace details {
2424

25-
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
25+
class SSAGraghBuilderWithChecker : public ir::Pass {
2626
protected:
2727
std::unique_ptr<ir::Graph> ApplyImpl(
2828
std::unique_ptr<ir::Graph> graph) const override {

paddle/fluid/framework/details/multi_devices_graph_builder.cc renamed to paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
2222
#include "paddle/fluid/framework/details/computation_op_handle.h"
2323
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
24-
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
24+
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
2525
#include "paddle/fluid/framework/details/reduce_op_handle.h"
2626
#include "paddle/fluid/framework/details/rpc_op_handle.h"
2727
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
@@ -33,6 +33,92 @@
3333
namespace paddle {
3434
namespace framework {
3535
namespace details {
36+
namespace {
37+
void PolishGraphToSupportDataHazards(ir::Graph *graph) {
38+
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
39+
for (auto &name_pair : var_map) {
40+
if (name_pair.second.size() <= 1) {
41+
continue;
42+
}
43+
auto it_new = name_pair.second.rbegin();
44+
auto it_old = name_pair.second.rbegin();
45+
++it_old;
46+
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
47+
OpHandleBase *write_op = (*it_new)->GeneratedOp();
48+
const auto &read_ops = (*it_old)->PendingOps();
49+
50+
for (auto *read_op : read_ops) {
51+
// Manually add a dependency var from read_op to write_op;
52+
if (read_op == write_op) {
53+
// Read Write is the same op.
54+
continue;
55+
}
56+
bool has_dep = false;
57+
for (auto *r_out : read_op->Outputs()) {
58+
for (auto *w_in : write_op->Inputs()) {
59+
if (r_out->Node() == w_in->Node()) {
60+
has_dep = true;
61+
break;
62+
}
63+
}
64+
}
65+
if (has_dep) continue;
66+
67+
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
68+
read_op->AddOutput(dep_var);
69+
write_op->AddInput(dep_var);
70+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
71+
}
72+
}
73+
}
74+
}
75+
}
76+
77+
VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
78+
const platform::Place &place,
79+
size_t place_offset) {
80+
auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
81+
auto &var_holder = var_holders[node->Name()];
82+
VarHandle *var = nullptr;
83+
if (var_holder.empty()) {
84+
if (node->Var()) {
85+
var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
86+
node->Name(), place);
87+
} else {
88+
var = new VarHandle(
89+
graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
90+
place_offset, node->Name(), place);
91+
}
92+
var_holder.emplace_back(var);
93+
} else {
94+
var = var_holder.rbegin()->get();
95+
}
96+
return var;
97+
}
98+
99+
void CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
100+
ir::Node *new_node, const platform::Place &place,
101+
size_t place_offset) {
102+
auto &vars =
103+
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
104+
size_t version = vars.size();
105+
auto var =
106+
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
107+
vars.emplace_back(var);
108+
op_handle->AddOutput(var);
109+
}
110+
111+
void AddOutputToLeafOps(ir::Graph *graph) {
112+
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
113+
if (!op->Outputs().empty()) {
114+
continue;
115+
}
116+
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
117+
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
118+
op->AddOutput(dummy_leaf);
119+
}
120+
}
121+
} // namespace
36122

37123
static const char kLossVarName[] = "loss_var_name";
38124
static const char kPlaces[] = "places";
@@ -751,7 +837,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
751837
} // namespace framework
752838
} // namespace paddle
753839

754-
REGISTER_PASS(multi_device_pass,
840+
REGISTER_PASS(multi_devices_pass,
755841
paddle::framework::details::MultiDevSSAGraphBuilder)
756842
.RequirePassAttr(paddle::framework::details::kLossVarName)
757843
.RequirePassAttr(paddle::framework::details::kPlaces)

paddle/fluid/framework/details/multi_devices_graph_builder.h renamed to paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <vector>
1919

2020
#include "paddle/fluid/framework/details/build_strategy.h"
21-
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
21+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
2222
#include "paddle/fluid/framework/ir/graph.h"
2323

2424
namespace paddle {
@@ -30,7 +30,7 @@ namespace framework {
3030
class Scope;
3131
namespace details {
3232

33-
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
33+
class MultiDevSSAGraphBuilder : public ir::Pass {
3434
protected:
3535
std::unique_ptr<ir::Graph> ApplyImpl(
3636
std::unique_ptr<ir::Graph> graph) const override;

paddle/fluid/framework/details/ssa_graph_printer.cc renamed to paddle/fluid/framework/details/multi_devices_graph_print_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
15+
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
1616
#include <string>
1717
#include "paddle/fluid/framework/ir/graph.h"
1818

@@ -82,5 +82,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
8282
} // namespace framework
8383
} // namespace paddle
8484

85-
REGISTER_PASS(multi_device_print_pass,
85+
REGISTER_PASS(multi_devices_print_pass,
8686
paddle::framework::details::SSAGraghBuilderWithPrinter);

paddle/fluid/framework/details/ssa_graph_printer.h renamed to paddle/fluid/framework/details/multi_devices_graph_print_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <iosfwd>
1919
#include <ostream>
2020
#include <string>
21-
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
21+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
2222

2323
namespace paddle {
2424
namespace framework {
@@ -35,7 +35,7 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
3535
void Print(const ir::Graph& graph, std::ostream& sout) const override;
3636
};
3737

38-
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
38+
class SSAGraghBuilderWithPrinter : public ir::Pass {
3939
protected:
4040
std::unique_ptr<ir::Graph> ApplyImpl(
4141
std::unique_ptr<ir::Graph> graph) const override {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
15+
16+
namespace paddle {
17+
namespace framework {
18+
namespace details {} // namespace details
19+
} // namespace framework
20+
} // namespace paddle

0 commit comments

Comments
 (0)