Commit b53f7e2

Merge pull request #9930 from reyoung/feature/simplify_delay_logic

Simplify DelayOps Logic

2 parents: 0729ea7 + 4999f85

File tree: 3 files changed, +18 −39 lines

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 15 additions & 35 deletions
@@ -33,13 +33,6 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       running_ops_(0),
       allow_op_delay_(allow_op_delay) {}
 
-void ThreadedSSAGraphExecutor::RunDelayedOps(
-    const std::unordered_set<OpHandleBase *> &delayed_ops) {
-  for (auto op : delayed_ops) {
-    op->Run(use_event_);
-  }
-}
-
 FeedFetchList ThreadedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   std::unordered_map<OpHandleBase *, size_t> pending_ops;
@@ -51,8 +44,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   // together since we currently cannot overlap computation and memcpy streams.
   // Should revisit it if overlapping is available.
   std::unordered_set<OpHandleBase *> delayed_ops;
-  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
-  std::unordered_set<VarHandleBase *> delayed_vars;
 
   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
     pending_vars.insert(&var);
@@ -122,24 +113,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     InsertPendingOp(*op);
   }
 
-  auto run_all_ready_ops = [&] {
-    for (auto *op : ready_ops) {
-      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
-        delayed_ops.insert(op);
-        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
-        ready_vars.Extend(op->outputs_);
-        continue;
-      }
+  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+    for (auto *op : set) {
       running_ops_++;
       RunOp(&ready_vars, op);
     }
-    ready_ops.clear();
+    set.clear();
   };
 
   // Step 3. Execution
-  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
+  while (!pending_vars.empty()) {
     // 1. Run All Ready ops
-    run_all_ready_ops();
+    // Keep loop until all vars are ready.
+    //
+    // NOTE: DelayedOps have a lower priority. It will be scheduled after all
+    // ready_ops have been performed.
+    if (ready_ops.empty() && allow_op_delay_) {
+      run_all_ops(delayed_ops);
+    } else {
+      run_all_ops(ready_ops);
+    }
 
     // 2. Find ready variable
     bool timeout;
@@ -160,29 +153,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         auto &deps = pending_ops[op];
         --deps;
         if (deps == 0) {
-          if (delayed_vars.find(ready_var) != delayed_vars.end()) {
-            blocked_by_delayed_ops.insert(op);
+          if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+            delayed_ops.insert(op);
           } else {
             ready_ops.insert(op);
           }
         }
       }
     }
-    // When there are no other ops to schedule, schedule buffered delayed
-    // ops and unblock other ops.
-    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
-      RunDelayedOps(delayed_ops);
-      delayed_ops.clear();
-      for (auto *op : blocked_by_delayed_ops) {
-        ready_ops.insert(op);
-      }
-      blocked_by_delayed_ops.clear();
-    }
-    // Keep loop until all vars are ready.
   }
   PADDLE_ENFORCE(ready_ops.empty());
-  PADDLE_ENFORCE(delayed_ops.empty());
-  PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
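
The substance of the change is in the execution loop above: the bookkeeping sets delayed_vars and blocked_by_delayed_ops are gone, and a newly ready op whose IsMultiDeviceTransfer() returns true is simply parked in delayed_ops, which is drained only once ready_ops is empty. Below is a minimal, self-contained sketch of that two-priority policy; the Op type and the two op instances are hypothetical stand-ins for illustration, not the real OpHandleBase/RunOp machinery.

#include <iostream>
#include <unordered_set>

// Hypothetical stand-in for OpHandleBase; illustration only.
struct Op {
  const char *name;
  bool is_multi_device_transfer;
  void Run() const { std::cout << "run " << name << '\n'; }
};

int main() {
  const bool allow_op_delay = true;
  std::unordered_set<Op *> ready_ops, delayed_ops;

  // Mirrors the `deps == 0` branch above: multi-device transfer ops are
  // parked in the low-priority set instead of blocking other ops.
  auto schedule = [&](Op *op) {
    if (op->is_multi_device_transfer && allow_op_delay) {
      delayed_ops.insert(op);
    } else {
      ready_ops.insert(op);
    }
  };

  Op fc{"fc", false};
  Op broadcast{"broadcast", true};
  schedule(&fc);
  schedule(&broadcast);

  // Mirrors the new run_all_ops lambda: drain whichever set was chosen.
  auto run_all_ops = [](std::unordered_set<Op *> &set) {
    for (auto *op : set) op->Run();
    set.clear();
  };

  // Delayed ops run only when no ordinary op is ready, i.e. they have a
  // strictly lower scheduling priority.
  while (!ready_ops.empty() || !delayed_ops.empty()) {
    if (ready_ops.empty() && allow_op_delay) {
      run_all_ops(delayed_ops);
    } else {
      run_all_ops(ready_ops);
    }
  }
  return 0;
}

Running this prints "run fc" before "run broadcast": the transfer op is deferred until no compute op is pending, which is exactly the priority rule the NOTE in the diff describes.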

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 0 additions & 2 deletions
@@ -88,8 +88,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
              details::OpHandleBase *op);
 
-  void RunDelayedOps(const std::unordered_set<OpHandleBase *> &delayed_ops);
-
  private:
   std::unique_ptr<::ThreadPool> pool_;
   std::vector<Scope *> local_scopes_;

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 3 additions & 2 deletions
@@ -206,18 +206,19 @@ def check_network_convergence(self,
                               feed_dict={}):
         main = fluid.Program()
         startup = fluid.Program()
+        startup.random_seed = 1  # Fix random seed
         with fluid.program_guard(main, startup):
             loss = method(use_feed=len(feed_dict) > 0)
             adam = fluid.optimizer.Adam()
             adam.minimize(loss)
             if memory_opt:
                 fluid.memory_optimize(main)
-
         place = fluid.CUDAPlace(0)
         startup_exe = fluid.Executor(place)
         startup_exe.run(startup)
 
-        exe = fluid.ParallelExecutor(True, loss_name=loss.name)
+        exe = fluid.ParallelExecutor(
+            True, loss_name=loss.name, allow_op_delay=allow_op_delay)
         if batch_size is not None:
             batch_size *= fluid.core.get_cuda_device_count()
         begin = time.time()
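
For completeness, here is a hedged usage sketch of the new allow_op_delay keyword, mirroring check_network_convergence above. The model (the 'img' layer, sizes, names) is illustrative only; the first positional argument True is the use-CUDA flag, as in the test, and the sketch assumes the Fluid API as of this commit.

import paddle.fluid as fluid
import paddle.fluid.layers as layers

# Build a trivial network, following the shape of the test harness.
main = fluid.Program()
startup = fluid.Program()
startup.random_seed = 1  # Fix random seed, as in the test
with fluid.program_guard(main, startup):
    img = layers.data(name='img', shape=[784], dtype='float32')
    loss = layers.mean(layers.fc(input=img, size=10))
    fluid.optimizer.Adam().minimize(loss)

place = fluid.CUDAPlace(0)
fluid.Executor(place).run(startup)

# The new keyword: defer multi-device transfer ops until no other op is
# ready, exercising the executor change in this commit.
exe = fluid.ParallelExecutor(True, loss_name=loss.name, allow_op_delay=True)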
