@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
27
27
VLOG (4 ) << " RunServer thread end" ;
28
28
}
29
29
30
// Helper deleted by this diff (every row below carries a '-' marker).
// Calls var->GetMutable<T>() with T chosen from var_type:
//   LOD_TENSOR    -> framework::LoDTensor
//   SELECTED_ROWS -> framework::SelectedRows
// Any other var_type goes to PADDLE_THROW with a message naming the two
// accepted types.
- static void CreateTensorFromMessageType (framework::Variable *var,
31
- sendrecv::VarType var_type) {
32
- if (var_type == sendrecv::VarType::LOD_TENSOR) {
33
- var->GetMutable <framework::LoDTensor>();
34
- } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
35
- var->GetMutable <framework::SelectedRows>();
36
- } else {
37
- PADDLE_THROW (
38
- " VariableMessage type %d is not in "
39
- " [LoDTensor, SelectedRows]" ,
40
- var_type);
41
- }
42
- }
43
-
44
30
static void ParallelExecuteBlocks (
45
31
const std::vector<size_t > ¶llel_blkids, framework::Executor *executor,
46
32
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -77,59 +63,37 @@ void ListenAndServOp::Stop() {
77
63
server_thread_->join ();
78
64
}
79
65
80
- void ListenAndServOp::RunImpl (const framework::Scope &scope,
81
- const platform::Place &dev_place) const {
82
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
83
- auto &dev_ctx = *pool.Get (dev_place);
84
- framework::Scope &recv_scope = scope.NewScope ();
85
-
86
- if (!rpc_service_) {
87
- std::string endpoint = Attr<std::string>(" endpoint" );
88
- rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
89
- }
66
// Registers everything the RPC service needs for the prefetch block:
// the executor, the prefetch block id, and a context prepared by
// executor->Prepare() for that block.
//
// executor       - stored in rpc_service_ via SetExecutor().
// prefetch_block - its ID() selects the block to prepare.
// program        - program passed to executor->Prepare().
//
// rpc_service_ is dereferenced unconditionally, so it must already be set
// when this is called.
+ void ListenAndServOp::PreparePrefetchCtx (
67
+ framework::Executor *executor, framework::BlockDesc *prefetch_block,
68
+ framework::ProgramDesc *program) const {
69
+ // TODO(qiao) set proper fields for table lookup and update
70
+ rpc_service_->SetExecutor (executor);
71
+ VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
72
+ auto prefetch_prepared = executor->Prepare (*program, prefetch_block->ID ());
73
+ rpc_service_->SetPrefetchBlkdId (prefetch_block->ID ());
// Hand the raw pointer to the service, then release() so the local smart
// pointer gives up ownership and does not destroy the context on return.
// NOTE(review): nothing visible here frees the released context; confirm
// that rpc_service_ (or another owner) deletes it.
74
+ rpc_service_->SetPrefetchPreparedCtx (prefetch_prepared.get ());
75
+ prefetch_prepared.release ();
76
+ }
90
77
91
- auto ins = Inputs (" X" );
78
+ void ListenAndServOp::RunSyncUpdate (
79
+ framework::Executor *executor, framework::ProgramDesc *program,
80
+ framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const {
92
81
auto fan_in = Attr<int >(" Fanin" );
93
- auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
94
- auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
95
- auto *program = optimize_block->Program ();
82
+
96
83
size_t num_blocks = program->Size ();
97
84
PADDLE_ENFORCE_GE (num_blocks, 2 ,
98
85
" server program should have at least 2 blocks" );
99
86
100
- framework::Executor executor (dev_place);
101
87
std::vector<int > block_list;
102
88
for (size_t blkid = 1 ; blkid < num_blocks; ++blkid) {
103
- if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
104
- block_list.push_back (blkid);
105
- }
89
+ block_list.push_back (blkid);
106
90
}
107
- auto optimize_prepared = executor. Prepare (*program, block_list);
91
+ auto optimize_prepared = executor-> Prepare (*program, block_list);
108
92
// Insert placeholder for block0 which holds current op itself.
109
93
optimize_prepared.insert (
110
94
optimize_prepared.begin (),
111
95
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr ));
112
96
113
- rpc_service_->SetScope (&recv_scope);
114
- rpc_service_->SetDevCtx (&dev_ctx);
115
- // TODO(qiao) set proper fields for table lookup and update
116
- rpc_service_->SetExecutor (&executor);
117
- VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
118
- auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
119
- rpc_service_->SetPrefetchBlkdId (prefetch_block->ID ());
120
- rpc_service_->SetPrefetchPreparedCtx (prefetch_prepared.get ());
121
- prefetch_prepared.release ();
122
- rpc_service_->SetProgram (program);
123
- // start the server listening after all member initialized.
124
- server_thread_.reset (new std::thread (RunServer, rpc_service_));
125
- VLOG (3 ) << " wait server thread to become ready..." ;
126
- sleep (5 );
127
- // Write to a file of server selected port for python use.
128
- std::ofstream port_file;
129
- port_file.open (" /tmp/paddle.selected_port" );
130
- port_file << rpc_service_->GetSelectedPort ();
131
- port_file.close ();
132
-
133
97
bool exit_flag = false ;
134
98
// Record received sparse variables, so that
135
99
// we could reset those after executing the optimize program
@@ -170,7 +134,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
170
134
break ;
171
135
}
172
136
173
- // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
137
+ // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
174
138
// and this will still work.
175
139
176
140
// The optimize blocks which have the same parent ID would run parallel
@@ -182,16 +146,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
182
146
for (size_t blkid = 2 ; blkid < num_blocks; ++blkid) {
183
147
if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
184
148
if (program->Block (blkid).Parent () != last_parent_blkid) {
185
- ParallelExecuteBlocks (parallel_blkids, & executor, optimize_prepared,
186
- program, & recv_scope);
149
+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
150
+ program, recv_scope);
187
151
parallel_blkids.clear ();
188
152
last_parent_blkid = program->Block (blkid).Parent ();
189
153
}
190
154
parallel_blkids.push_back (blkid);
191
155
}
192
156
}
193
- ParallelExecuteBlocks (parallel_blkids, & executor, optimize_prepared,
194
- program, & recv_scope);
157
+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared, program ,
158
+ recv_scope);
195
159
VLOG (2 ) << " run all blocks spent " << detail::GetTimestamp () - ts << " (ms)" ;
196
160
197
161
// Reset the received sparse variables, the sum operator would not
@@ -209,6 +173,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
209
173
} // while(true)
210
174
}
211
175
176
// Writes the value of rpc_service->GetSelectedPort() to the fixed path
// opened below, then closes the file.
// NOTE(review): the result of open() is not checked, so a failure to open
// the file goes unreported.
+ static void SavePort (std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
177
+ std::ofstream port_file;
178
+ port_file.open (" /tmp/paddle.selected_port" );
179
+ port_file << rpc_service->GetSelectedPort ();
180
+ port_file.close ();
181
+ }
182
+
183
// Entry point of the op. In order: creates the RPC service for the
// "endpoint" attribute, wires scope / device context / program / prefetch
// context into it, starts the server thread, saves the selected port to a
// file, and finally calls RunSyncUpdate().
//
// scope     - parent scope; a new child scope (recv_scope) is created from it.
// dev_place - place used to look up the device context and build the executor.
+ void ListenAndServOp::RunImpl (const framework::Scope &scope,
184
+ const platform::Place &dev_place) const {
185
+ platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
186
+ auto &dev_ctx = *pool.Get (dev_place);
187
+ framework::Scope &recv_scope = scope.NewScope ();
188
+
// rpc_service_ must be unset on entry; a second call with the service
// already created fails this check.
189
+ PADDLE_ENFORCE (!rpc_service_);
190
+ std::string endpoint = Attr<std::string>(" endpoint" );
191
+ rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
192
+
193
+ auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
194
+ auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
195
+ auto *program = optimize_block->Program ();
// executor, recv_scope and dev_ctx are passed to the service by address
// below. NOTE(review): executor is a local, so those pointers are only
// valid while RunImpl is still executing; confirm the service never uses
// them after RunSyncUpdate() returns.
196
+ framework::Executor executor (dev_place);
197
+
198
+ // Configure rpc_service_ before the server thread is started.
199
+ rpc_service_->SetScope (&recv_scope);
200
+ rpc_service_->SetDevCtx (&dev_ctx);
201
+ rpc_service_->SetProgram (program);
202
+ PreparePrefetchCtx (&executor, prefetch_block, program);
203
+ // Start the server thread only after all members are initialized.
204
+ server_thread_.reset (new std::thread (RunServer, rpc_service_));
205
+ VLOG (3 ) << " wait server thread to become ready..." ;
// Fixed 5-second wait instead of a readiness signal from the server thread.
// NOTE(review): GetSelectedPort() is read right after this, so the port is
// presumably only valid once the server is up; confirm 5 s always suffices.
206
+ sleep (5 );
207
+ // Write the server's selected port to a file for Python-side use.
208
+ SavePort (rpc_service_);
209
+ RunSyncUpdate (&executor, program, &recv_scope, prefetch_block);
210
+ }
211
+
212
212
class ListenAndServOpMaker : public framework ::OpProtoAndCheckerMaker {
213
213
public:
214
214
ListenAndServOpMaker (OpProto *proto, OpAttrChecker *op_checker)
0 commit comments