Skip to content

Commit f4851f1

Browse files
committed
clean code
1 parent 22ab14c commit f4851f1

File tree

3 files changed

+80
-48
lines changed

3 files changed

+80
-48
lines changed

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
1616

17-
#include "paddle/fluid/framework/details/fetch_op_handle.h"
18-
1917
namespace paddle {
2018
namespace framework {
2119
namespace details {
@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
4543
// Should revisit it if overlapping is available.
4644
std::unordered_set<OpHandleBase *> delayed_ops;
4745

48-
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
49-
pending_vars.insert(&var);
50-
if (var.generated_op_ == nullptr) {
51-
ready_vars.Push(&var);
52-
}
53-
};
54-
55-
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
56-
pending_ops.insert({&op_instance, op_instance.Inputs().size()});
57-
};
58-
5946
// Transform SSAGraph to pending_ops & pending_vars
6047
for (auto &var_map : graph_->vars_) {
6148
for (auto &name_pair : var_map) {
6249
for (auto &version_pair : name_pair.second) {
63-
InsertPendingVar(*version_pair);
50+
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
6451
}
6552
}
6653
}
6754
for (auto &var : graph_->dep_vars_) {
68-
InsertPendingVar(*var);
55+
InsertPendingVar(&pending_vars, &ready_vars, var.get());
6956
}
7057

7158
for (auto &op : graph_->ops_) {
7259
if (op->Inputs().empty()) { // Special case, Op has no input.
7360
ready_ops.insert(op.get());
7461
} else {
75-
InsertPendingOp(*op);
62+
InsertPendingOp(&pending_ops, op.get());
7663
}
7764
}
7865

7966
// Step 2. Insert FetchOps
8067
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
81-
FeedFetchList fetch_data(fetch_tensors.size());
82-
83-
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
84-
85-
for (auto &fetch_var_name : fetch_tensors) {
86-
for (auto &var_map : graph_->vars_) {
87-
auto it = var_map.find(fetch_var_name);
88-
if (it != var_map.end()) {
89-
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
90-
}
91-
}
92-
}
93-
9468
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
95-
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
96-
auto &var_name = fetch_tensors[i];
97-
auto &vars = fetched_vars.at(var_name);
98-
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
99-
fetch_ops.emplace_back(op);
100-
101-
for (auto &p : places_) {
102-
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
103-
}
104-
105-
for (auto *var : vars) {
106-
op->AddInput(var);
107-
}
69+
FeedFetchList fetch_data(fetch_tensors.size());
10870

109-
auto *fetch_dummy = new DummyVarHandle();
110-
op->AddOutput(fetch_dummy);
111-
fetch_dependencies.emplace(fetch_dummy);
112-
InsertPendingVar(*fetch_dummy);
113-
InsertPendingOp(*op);
114-
}
71+
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
72+
&pending_vars, &ready_vars, &fetch_data);
11573

11674
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
11775
for (auto *op : set) {
@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
174132
return fetch_data;
175133
}
176134

135+
void ThreadedSSAGraphExecutor::InsertFetchOps(
136+
const std::vector<std::string> &fetch_tensors,
137+
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
138+
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
139+
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
140+
std::unordered_set<VarHandleBase *> *pending_vars,
141+
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
142+
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
143+
144+
for (auto &fetch_var_name : fetch_tensors) {
145+
for (auto &var_map : graph_->vars_) {
146+
auto it = var_map.find(fetch_var_name);
147+
if (it != var_map.end()) {
148+
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
149+
}
150+
}
151+
}
152+
153+
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
154+
auto &var_name = fetch_tensors[i];
155+
auto &vars = fetched_vars.at(var_name);
156+
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
157+
fetch_ops->emplace_back(op);
158+
159+
for (auto &p : places_) {
160+
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
161+
}
162+
163+
for (auto *var : vars) {
164+
op->AddInput(var);
165+
}
166+
167+
auto *fetch_dummy = new DummyVarHandle();
168+
op->AddOutput(fetch_dummy);
169+
fetch_dependencies->emplace(fetch_dummy);
170+
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
171+
this->InsertPendingOp(pending_ops, op);
172+
}
173+
}
174+
175+
void ThreadedSSAGraphExecutor::InsertPendingOp(
176+
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
177+
OpHandleBase *op_instance) const {
178+
pending_ops->insert({op_instance, op_instance->Inputs().size()});
179+
}
180+
181+
void ThreadedSSAGraphExecutor::InsertPendingVar(
182+
std::unordered_set<VarHandleBase *> *pending_vars,
183+
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
184+
pending_vars->insert(var);
185+
if (var->generated_op_ == nullptr) {
186+
ready_vars->Push(var);
187+
}
188+
}
177189
void ThreadedSSAGraphExecutor::RunOp(
178190
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
179191
auto op_run = [ready_var_q, op, this] {

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <functional>
2424
#include "ThreadPool.h" // ThreadPool in thrird party
2525
#include "paddle/fluid/framework/blocking_queue.h"
26+
#include "paddle/fluid/framework/details/fetch_op_handle.h"
2627
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
2728

2829
namespace paddle {
@@ -58,6 +59,21 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5859
std::unique_ptr<platform::EnforceNotMet> exception_;
5960
std::atomic<int> running_ops_;
6061
bool allow_op_delay_;
62+
63+
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
64+
OpHandleBase *op_instance) const;
65+
66+
void InsertPendingVar(std::unordered_set<VarHandleBase *> *pending_vars,
67+
BlockingQueue<VarHandleBase *> *ready_vars,
68+
VarHandleBase *var) const;
69+
70+
void InsertFetchOps(
71+
const std::vector<std::string> &fetch_tensors,
72+
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
73+
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
74+
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
75+
std::unordered_set<VarHandleBase *> *pending_vars,
76+
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data);
6177
};
6278

6379
} // namespace details

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,3 +721,7 @@ def test_update_sparse_parameter(self):
721721

722722
def test_update_dense_parameter(self):
723723
self.check_network_convergence(is_sparse=False)
724+
725+
726+
if __name__ == '__main__':
727+
unittest.main()

0 commit comments

Comments
 (0)