@@ -33,13 +33,6 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       running_ops_(0),
       allow_op_delay_(allow_op_delay) {}
 
-void ThreadedSSAGraphExecutor::RunDelayedOps(
-    const std::unordered_set<OpHandleBase *> &delayed_ops) {
-  for (auto op : delayed_ops) {
-    op->Run(use_event_);
-  }
-}
-
 FeedFetchList ThreadedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   std::unordered_map<OpHandleBase *, size_t> pending_ops;
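Note: the deleted `RunDelayedOps` ran each delayed op inline, without touching the `running_ops_` counter or the ready-variable queue that `RunOp` maintains; after this patch, delayed ops are drained through the same `RunOp` path as every other op. A minimal self-contained sketch of the single-dispatch idea (toy types and names, not Paddle's API):

```cpp
// Toy sketch (hypothetical names, not Paddle's API): funneling every op
// through one Dispatch() keeps the in-flight counter and completion queue
// consistent, which the deleted inline loop did not.
#include <iostream>
#include <queue>
#include <string>
#include <unordered_set>

struct ToyOp {
  std::string name;
  void Run() { std::cout << "ran " << name << "\n"; }
};

int running_ops = 0;                 // mirrors running_ops_
std::queue<std::string> ready_vars;  // mirrors the ready-variable queue

void Dispatch(ToyOp *op) {           // single execution path for all ops
  ++running_ops;
  op->Run();                         // the real executor enqueues on a pool
  ready_vars.push(op->name + ".out");
  --running_ops;
}

int main() {
  std::unordered_set<ToyOp *> delayed_ops;
  ToyOp copy{"h2d_copy"};
  delayed_ops.insert(&copy);
  for (ToyOp *op : delayed_ops) Dispatch(op);  // no special-case helper
  delayed_ops.clear();
}
```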
@@ -51,8 +44,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   // together since we currently cannot overlap computation and memcpy streams.
   // Should revisit it if overlapping is available.
   std::unordered_set<OpHandleBase *> delayed_ops;
-  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
-  std::unordered_set<VarHandleBase *> delayed_vars;
 
   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
     pending_vars.insert(&var);
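Note: `delayed_vars` and `blocked_by_delayed_ops` existed only to support the old scheme, in which a delayed op's outputs were marked ready before the op actually ran (see the removed `ready_vars.Extend` call in the next hunk), so consumers of those outputs had to be parked in a second set. A toy illustration of the bookkeeping that disappears (hypothetical names, not Paddle's API):

```cpp
// Toy sketch of the retired bookkeeping (hypothetical names). Marking a
// delayed op's outputs ready before it runs forces a second set to park
// consumers of those outputs until the producer really finishes.
#include <cassert>
#include <unordered_set>

struct Var {};
struct Op {
  std::unordered_set<Var *> outputs;
};

int main() {
  Op producer;
  Var out;
  producer.outputs.insert(&out);

  std::unordered_set<Var *> delayed_vars;  // vars that are only "fake ready"
  std::unordered_set<Op *> blocked_ops;    // consumers parked on them

  // Old scheme: pretend the outputs are ready so the loop keeps moving...
  delayed_vars.insert(producer.outputs.begin(), producer.outputs.end());

  // ...and then every consumer whose deps hit zero must be re-checked.
  Op consumer;
  Var *ready_var = &out;
  if (delayed_vars.count(ready_var)) {
    blocked_ops.insert(&consumer);  // cannot run yet despite deps == 0
  }
  assert(blocked_ops.size() == 1);
}
```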
@@ -122,24 +113,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     InsertPendingOp(*op);
   }
 
-  auto run_all_ready_ops = [&] {
-    for (auto *op : ready_ops) {
-      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
-        delayed_ops.insert(op);
-        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
-        ready_vars.Extend(op->outputs_);
-        continue;
-      }
+  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+    for (auto *op : set) {
       running_ops_++;
       RunOp(&ready_vars, op);
     }
-    ready_ops.clear();
+    set.clear();
   };
 
   // Step 3. Execution
-  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
+  while (!pending_vars.empty()) {
     // 1. Run All Ready ops
-    run_all_ready_ops();
+    // Keep looping until all vars are ready.
+    //
+    // NOTE: DelayedOps have a lower priority. They will be scheduled after
+    // all ready_ops have been performed.
+    if (ready_ops.empty() && allow_op_delay_) {
+      run_all_ops(delayed_ops);
+    } else {
+      run_all_ops(ready_ops);
+    }
 
     // 2. Find ready variable
     bool timeout;
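Note: this hunk is the core of the patch. The special-cased `run_all_ready_ops` becomes a generic `run_all_ops` over whichever set the caller passes, and delayed ops become a strictly lower-priority queue that is drained only when no ready op is left, while the loop keys off `pending_vars` alone. A self-contained sketch of that drain policy (toy types, hypothetical names):

```cpp
// Minimal sketch of the new drain policy (toy types, hypothetical names):
// delayed ops form a second, strictly lower-priority queue that is only
// touched when no ordinary ready op remains.
#include <iostream>
#include <string>
#include <unordered_set>

struct ToyOp {
  std::string name;
};

int main() {
  ToyOp compute{"compute"}, copy{"h2d_copy"};
  std::unordered_set<ToyOp *> ready_ops{&compute};
  std::unordered_set<ToyOp *> delayed_ops{&copy};
  const bool allow_op_delay = true;

  // Same shape as the patched lambda: run whichever set is handed in, then
  // clear it. Safe because running an op never mutates these sets directly.
  auto run_all_ops = [](std::unordered_set<ToyOp *> &set) {
    for (ToyOp *op : set) std::cout << "run " << op->name << "\n";
    set.clear();
  };

  while (!ready_ops.empty() || !delayed_ops.empty()) {
    if (ready_ops.empty() && allow_op_delay) {
      run_all_ops(delayed_ops);  // lower priority: only when nothing else
    } else {
      run_all_ops(ready_ops);
    }
  }  // prints "run compute", then "run h2d_copy"
}
```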
@@ -160,29 +153,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
       auto &deps = pending_ops[op];
       --deps;
       if (deps == 0) {
-        if (delayed_vars.find(ready_var) != delayed_vars.end()) {
-          blocked_by_delayed_ops.insert(op);
+        if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+          delayed_ops.insert(op);
         } else {
           ready_ops.insert(op);
         }
       }
     }
   }
-    // When there are no other ops to schedule, schedule buffered delayed
-    // ops and unblock other ops.
-    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
-      RunDelayedOps(delayed_ops);
-      delayed_ops.clear();
-      for (auto *op : blocked_by_delayed_ops) {
-        ready_ops.insert(op);
-      }
-      blocked_by_delayed_ops.clear();
-    }
-    // Keep loop until all vars are ready.
   }
   PADDLE_ENFORCE(ready_ops.empty());
-  PADDLE_ENFORCE(delayed_ops.empty());
-  PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
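Note: classification now happens exactly once, at the moment an op's pending-dependency count reaches zero, so the run-time filtering inside the old `run_all_ready_ops` and the end-of-iteration flush both disappear, and only the `ready_ops` assertion survives after the loop. A compact sketch of the dependency-count bookkeeping around the new branch (toy types, hypothetical names):

```cpp
// Toy dependency-resolution sketch (hypothetical names, not Paddle's API):
// when an op's pending-dependency count hits zero it is classified once,
// either as immediately runnable or as a low-priority transfer.
#include <cassert>
#include <unordered_map>
#include <unordered_set>

struct Op {
  bool is_multi_device_transfer = false;
};

int main() {
  const bool allow_op_delay = true;
  Op gemm, h2d{true};
  std::unordered_map<Op *, size_t> pending_ops{{&gemm, 1}, {&h2d, 1}};
  std::unordered_set<Op *> ready_ops, delayed_ops;

  for (auto *op : {&gemm, &h2d}) {  // a variable both ops wait on is ready
    auto &deps = pending_ops[op];
    --deps;
    if (deps == 0) {
      if (op->is_multi_device_transfer && allow_op_delay) {
        delayed_ops.insert(op);     // runs only once ready_ops drains
      } else {
        ready_ops.insert(op);
      }
    }
  }
  assert(ready_ops.count(&gemm) && delayed_ops.count(&h2d));
}
```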