@@ -54,20 +54,24 @@ static void CreateTensorFromMessageType(framework::Variable *var,
   }
 }
 
-static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
-                                  framework::Executor *executor,
-                                  framework::ProgramDesc *program,
-                                  framework::Scope *scope) {
+static void ParallelExecuteBlocks(
+    const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
+    const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
+        &prepared,
+    framework::ProgramDesc *program, framework::Scope *scope) {
   std::vector<std::future<void>> fs;
   for (size_t idx : parallel_blkids) {
-    fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
-      int run_block = idx;  // thread local
-      try {
-        executor->Run(*program, scope, run_block, false, false);
-      } catch (std::exception &e) {
-        LOG(ERROR) << "run sub program error " << e.what();
-      }
-    }));
+    fs.push_back(
+        framework::Async([&executor, &prepared, &program, &scope, idx]() {
+          int run_block = idx;  // thread local
+          try {
+            // executor->Run(*program, scope, run_block, false, false);
+            executor->RunPreparedContext(prepared[run_block].get(), scope,
+                                         false, false);
+          } catch (std::exception &e) {
+            LOG(ERROR) << "run sub program error " << e.what();
+          }
+        }));
   }
   for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
 }
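
Note: the hunk above makes the worker threads run contexts built once by Executor::Prepare via Executor::RunPreparedContext, instead of calling Executor::Run per request, so each sub-block's execution context can be reused across incoming gradients rather than rebuilt every time. A minimal sketch of that "prepare once, run many times" call pattern, based only on the calls visible in this diff (the include path and the helper names PrepareOptimizeBlocks / RunOptimizeBlock are assumptions for illustration, not part of the operator):

// Sketch only, not the operator itself: prepare each optimize sub-block once,
// then reuse the prepared contexts for every request.
#include <memory>
#include <vector>
#include "paddle/fluid/framework/executor.h"  // assumed header location

using paddle::framework::Executor;
using paddle::framework::ExecutorPrepareContext;
using paddle::framework::ProgramDesc;
using paddle::framework::Scope;

// Hypothetical helper: one prepared context per sub-block 1..N-1.
std::vector<std::shared_ptr<ExecutorPrepareContext>> PrepareOptimizeBlocks(
    Executor *executor, const ProgramDesc &program) {
  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < program.Size(); ++blkid) {
    block_list.push_back(blkid);
  }
  // Done once, outside the serving loop.
  return executor->Prepare(program, block_list);
}

// Hypothetical per-request path: reuse the prepared context for one block.
void RunOptimizeBlock(Executor *executor,
                      const std::shared_ptr<ExecutorPrepareContext> &ctx,
                      Scope *scope) {
  // Same flags as in the diff: no new local scope, no variable creation.
  executor->RunPreparedContext(ctx.get(), scope, false, false);
}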
@@ -105,15 +109,18 @@ class ListenAndServOp : public framework::OperatorBase {
 
     auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
     auto *program = block->Program();
-    int num_blocks = program->Size();
+    size_t num_blocks = program->Size();
     PADDLE_ENFORCE_GE(num_blocks, 2,
                       "server program should have at least 2 blocks");
 
     framework::Executor executor(dev_place);
     std::vector<int> block_list;
-    for (int blkid = 1; blkid < num_blocks; ++blkid)
+    for (size_t blkid = 1; blkid < num_blocks; ++blkid)
       block_list.push_back(blkid);
     auto prepared = executor.Prepare(*program, block_list);
+    prepared.insert(
+        prepared.begin(),
+        std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
 
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
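
Note: Prepare is called only for blocks 1..num_blocks-1 (block 0, the program's global block, is skipped), so the returned vector would start with block 1's context at index 0. The null placeholder inserted at the front keeps the vector indexed by block id, which ParallelExecuteBlocks relies on when it looks up prepared[run_block]. A standalone illustration of that alignment (hypothetical Ctx type, not PaddlePaddle code):

// Standalone sketch of the index alignment achieved by the nullptr placeholder.
#include <cassert>
#include <memory>
#include <vector>

int main() {
  struct Ctx { int block_id; };
  // Stand-in for executor.Prepare(*program, {1, 2}): contexts for blocks 1..2 only.
  std::vector<std::shared_ptr<Ctx>> prepared = {std::make_shared<Ctx>(Ctx{1}),
                                                std::make_shared<Ctx>(Ctx{2})};
  // Without a placeholder, prepared[1] would be block 2's context (off by one).
  prepared.insert(prepared.begin(), std::shared_ptr<Ctx>(nullptr));
  assert(prepared[1]->block_id == 1);  // prepared[blkid] now matches blkid
  assert(prepared[2]->block_id == 2);
  return 0;
}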
@@ -161,21 +168,22 @@ class ListenAndServOp : public framework::OperatorBase {
 
       // The optimize blocks which have the same parent ID would run parallel
       // TODO(Yancey1989): need to use ParallelExecutor for future
-      size_t last_parent_blkid = program->Block(1).Parent();
+      int32_t last_parent_blkid = program->Block(1).Parent();
       std::vector<size_t> parallel_blkids;
       parallel_blkids.push_back(1);
       double ts = detail::GetTimestamp();
       for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
         if (program->Block(blkid).Parent() != last_parent_blkid) {
           for (size_t idx : parallel_blkids) VLOG(3) << idx;
-          ParallelExecuteBlocks(parallel_blkids, &executor, program,
+          ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
                                 &recv_scope);
           parallel_blkids.clear();
           last_parent_blkid = program->Block(blkid).Parent();
         }
         parallel_blkids.push_back(blkid);
       }
-      ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
+      ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
+                            &recv_scope);
 
       VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
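
Note: the loop in this hunk batches consecutive sub-blocks that share a parent block ID and hands each batch to ParallelExecuteBlocks, so only blocks with the same parent run concurrently. A standalone sketch of that grouping logic (the parents layout below is hypothetical, not taken from a real program):

// Standalone sketch of grouping consecutive blocks by parent block ID.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // parents[blkid] stands in for program->Block(blkid).Parent().
  std::vector<int32_t> parents = {-1, 0, 0, 0, 2, 2};  // hypothetical layout
  int32_t last_parent = parents[1];
  std::vector<size_t> batch = {1};
  for (size_t blkid = 2; blkid < parents.size(); ++blkid) {
    if (parents[blkid] != last_parent) {
      // Blocks collected in `batch` share a parent and would run in parallel.
      std::printf("run batch of %zu block(s), parent %d\n", batch.size(),
                  last_parent);
      batch.clear();
      last_parent = parents[blkid];
    }
    batch.push_back(blkid);
  }
  // The trailing batch, mirroring the final ParallelExecuteBlocks call above.
  std::printf("run batch of %zu block(s), parent %d\n", batch.size(),
              last_parent);
  return 0;
}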