|
14 | 14 |
|
15 | 15 | #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
|
16 | 16 |
|
17 |
| -#include "paddle/fluid/framework/details/fetch_op_handle.h" |
18 |
| - |
19 | 17 | namespace paddle {
|
20 | 18 | namespace framework {
|
21 | 19 | namespace details {
|
@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
45 | 43 | // Should revisit it if overlapping is available.
|
46 | 44 | std::unordered_set<OpHandleBase *> delayed_ops;
|
47 | 45 |
|
48 |
| - auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { |
49 |
| - pending_vars.insert(&var); |
50 |
| - if (var.generated_op_ == nullptr) { |
51 |
| - ready_vars.Push(&var); |
52 |
| - } |
53 |
| - }; |
54 |
| - |
55 |
| - auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) { |
56 |
| - pending_ops.insert({&op_instance, op_instance.Inputs().size()}); |
57 |
| - }; |
58 |
| - |
59 | 46 | // Transform SSAGraph to pending_ops & pending_vars
|
60 | 47 | for (auto &var_map : graph_->vars_) {
|
61 | 48 | for (auto &name_pair : var_map) {
|
62 | 49 | for (auto &version_pair : name_pair.second) {
|
63 |
| - InsertPendingVar(*version_pair); |
| 50 | + InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); |
64 | 51 | }
|
65 | 52 | }
|
66 | 53 | }
|
67 | 54 | for (auto &var : graph_->dep_vars_) {
|
68 |
| - InsertPendingVar(*var); |
| 55 | + InsertPendingVar(&pending_vars, &ready_vars, var.get()); |
69 | 56 | }
|
70 | 57 |
|
71 | 58 | for (auto &op : graph_->ops_) {
|
72 | 59 | if (op->Inputs().empty()) { // Special case, Op has no input.
|
73 | 60 | ready_ops.insert(op.get());
|
74 | 61 | } else {
|
75 |
| - InsertPendingOp(*op); |
| 62 | + InsertPendingOp(&pending_ops, op.get()); |
76 | 63 | }
|
77 | 64 | }
|
78 | 65 |
|
79 | 66 | // Step 2. Insert FetchOps
|
80 | 67 | std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
|
81 |
| - FeedFetchList fetch_data(fetch_tensors.size()); |
82 |
| - |
83 |
| - std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; |
84 |
| - |
85 |
| - for (auto &fetch_var_name : fetch_tensors) { |
86 |
| - for (auto &var_map : graph_->vars_) { |
87 |
| - auto it = var_map.find(fetch_var_name); |
88 |
| - if (it != var_map.end()) { |
89 |
| - fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); |
90 |
| - } |
91 |
| - } |
92 |
| - } |
93 |
| - |
94 | 68 | std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
|
95 |
| - for (size_t i = 0; i < fetch_tensors.size(); ++i) { |
96 |
| - auto &var_name = fetch_tensors[i]; |
97 |
| - auto &vars = fetched_vars.at(var_name); |
98 |
| - auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_); |
99 |
| - fetch_ops.emplace_back(op); |
100 |
| - |
101 |
| - for (auto &p : places_) { |
102 |
| - op->SetDeviceContext(p, fetch_ctxs_.Get(p)); |
103 |
| - } |
104 |
| - |
105 |
| - for (auto *var : vars) { |
106 |
| - op->AddInput(var); |
107 |
| - } |
| 69 | + FeedFetchList fetch_data(fetch_tensors.size()); |
108 | 70 |
|
109 |
| - auto *fetch_dummy = new DummyVarHandle(); |
110 |
| - op->AddOutput(fetch_dummy); |
111 |
| - fetch_dependencies.emplace(fetch_dummy); |
112 |
| - InsertPendingVar(*fetch_dummy); |
113 |
| - InsertPendingOp(*op); |
114 |
| - } |
| 71 | + InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, |
| 72 | + &pending_vars, &ready_vars, &fetch_data); |
115 | 73 |
|
116 | 74 | auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
|
117 | 75 | for (auto *op : set) {
|
@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
|
174 | 132 | return fetch_data;
|
175 | 133 | }
|
176 | 134 |
|
| 135 | +void ThreadedSSAGraphExecutor::InsertFetchOps( |
| 136 | + const std::vector<std::string> &fetch_tensors, |
| 137 | + std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, |
| 138 | + std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, |
| 139 | + std::unordered_map<OpHandleBase *, size_t> *pending_ops, |
| 140 | + std::unordered_set<VarHandleBase *> *pending_vars, |
| 141 | + BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) { |
| 142 | + std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; |
| 143 | + |
| 144 | + for (auto &fetch_var_name : fetch_tensors) { |
| 145 | + for (auto &var_map : graph_->vars_) { |
| 146 | + auto it = var_map.find(fetch_var_name); |
| 147 | + if (it != var_map.end()) { |
| 148 | + fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); |
| 149 | + } |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + for (size_t i = 0; i < fetch_tensors.size(); ++i) { |
| 154 | + auto &var_name = fetch_tensors[i]; |
| 155 | + auto &vars = fetched_vars.at(var_name); |
| 156 | + auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_); |
| 157 | + fetch_ops->emplace_back(op); |
| 158 | + |
| 159 | + for (auto &p : places_) { |
| 160 | + op->SetDeviceContext(p, fetch_ctxs_.Get(p)); |
| 161 | + } |
| 162 | + |
| 163 | + for (auto *var : vars) { |
| 164 | + op->AddInput(var); |
| 165 | + } |
| 166 | + |
| 167 | + auto *fetch_dummy = new DummyVarHandle(); |
| 168 | + op->AddOutput(fetch_dummy); |
| 169 | + fetch_dependencies->emplace(fetch_dummy); |
| 170 | + this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); |
| 171 | + this->InsertPendingOp(pending_ops, op); |
| 172 | + } |
| 173 | +} |
| 174 | + |
| 175 | +void ThreadedSSAGraphExecutor::InsertPendingOp( |
| 176 | + std::unordered_map<OpHandleBase *, size_t> *pending_ops, |
| 177 | + OpHandleBase *op_instance) const { |
| 178 | + pending_ops->insert({op_instance, op_instance->Inputs().size()}); |
| 179 | +} |
| 180 | + |
| 181 | +void ThreadedSSAGraphExecutor::InsertPendingVar( |
| 182 | + std::unordered_set<VarHandleBase *> *pending_vars, |
| 183 | + BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const { |
| 184 | + pending_vars->insert(var); |
| 185 | + if (var->generated_op_ == nullptr) { |
| 186 | + ready_vars->Push(var); |
| 187 | + } |
| 188 | +} |
177 | 189 | void ThreadedSSAGraphExecutor::RunOp(
|
178 | 190 | BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
|
179 | 191 | auto op_run = [ready_var_q, op, this] {
|
|
0 commit comments