@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
27
27
VLOG (4 ) << " RunServer thread end" ;
28
28
}
29
29
30
- static void CreateTensorFromMessageType (framework::Variable *var,
31
- sendrecv::VarType var_type) {
32
- if (var_type == sendrecv::VarType::LOD_TENSOR) {
33
- var->GetMutable <framework::LoDTensor>();
34
- } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
35
- var->GetMutable <framework::SelectedRows>();
36
- } else {
37
- PADDLE_THROW (
38
- " VariableMessage type %d is not in "
39
- " [LoDTensor, SelectedRows]" ,
40
- var_type);
41
- }
42
- }
43
-
44
30
static void ParallelExecuteBlocks (
45
31
const std::vector<size_t > ¶llel_blkids, framework::Executor *executor,
46
32
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -62,6 +48,13 @@ static void ParallelExecuteBlocks(
62
48
for (size_t i = 0 ; i < fs.size (); ++i) fs[i].wait ();
63
49
}
64
50
51
+ static void SavePort (std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
52
+ std::ofstream port_file;
53
+ port_file.open (" /tmp/paddle.selected_port" );
54
+ port_file << rpc_service->GetSelectedPort ();
55
+ port_file.close ();
56
+ }
57
+
65
58
ListenAndServOp::ListenAndServOp (const std::string &type,
66
59
const framework::VariableNameMap &inputs,
67
60
const framework::VariableNameMap &outputs,
@@ -77,59 +70,26 @@ void ListenAndServOp::Stop() {
77
70
server_thread_->join ();
78
71
}
79
72
80
- void ListenAndServOp::RunImpl (const framework::Scope &scope,
81
- const platform::Place &dev_place) const {
82
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
83
- auto &dev_ctx = *pool.Get (dev_place);
84
- framework::Scope &recv_scope = scope.NewScope ();
85
-
86
- if (!rpc_service_) {
87
- std::string endpoint = Attr<std::string>(" endpoint" );
88
- rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
89
- }
90
-
91
- auto ins = Inputs (" X" );
73
+ void ListenAndServOp::RunSyncLoop (framework::Executor *executor,
74
+ framework::ProgramDesc *program,
75
+ framework::Scope *recv_scope,
76
+ framework::BlockDesc *prefetch_block) const {
92
77
auto fan_in = Attr<int >(" Fanin" );
93
- auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
94
- auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
95
- auto *program = optimize_block->Program ();
78
+
96
79
size_t num_blocks = program->Size ();
97
80
PADDLE_ENFORCE_GE (num_blocks, 2 ,
98
81
" server program should have at least 2 blocks" );
99
82
100
- framework::Executor executor (dev_place);
101
83
std::vector<int > block_list;
102
84
for (size_t blkid = 1 ; blkid < num_blocks; ++blkid) {
103
- if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
104
- block_list.push_back (blkid);
105
- }
85
+ block_list.push_back (blkid);
106
86
}
107
- auto optimize_prepared = executor. Prepare (*program, block_list);
87
+ auto optimize_prepared = executor-> Prepare (*program, block_list);
108
88
// Insert placeholder for block0 which holds current op itself.
109
89
optimize_prepared.insert (
110
90
optimize_prepared.begin (),
111
91
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr ));
112
92
113
- rpc_service_->SetScope (&recv_scope);
114
- rpc_service_->SetDevCtx (&dev_ctx);
115
- // TODO(qiao) set proper fields for table lookup and update
116
- rpc_service_->SetExecutor (&executor);
117
- VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
118
- auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
119
- rpc_service_->SetPrefetchBlkdId (prefetch_block->ID ());
120
- rpc_service_->SetPrefetchPreparedCtx (prefetch_prepared.get ());
121
- prefetch_prepared.release ();
122
- rpc_service_->SetProgram (program);
123
- // start the server listening after all member initialized.
124
- server_thread_.reset (new std::thread (RunServer, rpc_service_));
125
- VLOG (3 ) << " wait server thread to become ready..." ;
126
- sleep (5 );
127
- // Write to a file of server selected port for python use.
128
- std::ofstream port_file;
129
- port_file.open (" /tmp/paddle.selected_port" );
130
- port_file << rpc_service_->GetSelectedPort ();
131
- port_file.close ();
132
-
133
93
bool exit_flag = false ;
134
94
// Record received sparse variables, so that
135
95
// we could reset those after execute optimize program
@@ -170,7 +130,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
170
130
break ;
171
131
}
172
132
173
- // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
133
+ // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
174
134
// and this will still work.
175
135
176
136
// The optimize blocks which have the same parent ID would run parallel
@@ -182,16 +142,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
182
142
for (size_t blkid = 2 ; blkid < num_blocks; ++blkid) {
183
143
if (blkid != static_cast <size_t >(prefetch_block->ID ())) {
184
144
if (program->Block (blkid).Parent () != last_parent_blkid) {
185
- ParallelExecuteBlocks (parallel_blkids, & executor, optimize_prepared,
186
- program, & recv_scope);
145
+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared,
146
+ program, recv_scope);
187
147
parallel_blkids.clear ();
188
148
last_parent_blkid = program->Block (blkid).Parent ();
189
149
}
190
150
parallel_blkids.push_back (blkid);
191
151
}
192
152
}
193
- ParallelExecuteBlocks (parallel_blkids, & executor, optimize_prepared,
194
- program, & recv_scope);
153
+ ParallelExecuteBlocks (parallel_blkids, executor, optimize_prepared, program ,
154
+ recv_scope);
195
155
VLOG (2 ) << " run all blocks spent " << detail::GetTimestamp () - ts << " (ms)" ;
196
156
197
157
// Reset the received sparse variables, the sum operator would not
@@ -209,6 +169,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
209
169
} // while(true)
210
170
}
211
171
172
+ void ListenAndServOp::RunImpl (const framework::Scope &scope,
173
+ const platform::Place &dev_place) const {
174
+ platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
175
+ auto &dev_ctx = *pool.Get (dev_place);
176
+ framework::Scope &recv_scope = scope.NewScope ();
177
+
178
+ PADDLE_ENFORCE (!rpc_service_);
179
+ std::string endpoint = Attr<std::string>(" endpoint" );
180
+ rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
181
+
182
+ auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock );
183
+ auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock );
184
+ auto *program = optimize_block->Program ();
185
+ framework::Executor executor (dev_place);
186
+
187
+ // prepare rpc_service
188
+ rpc_service_->SetScope (&recv_scope);
189
+ rpc_service_->SetDevCtx (&dev_ctx);
190
+ rpc_service_->SetProgram (program);
191
+ rpc_service_->SetExecutor (&executor);
192
+
193
+ // prepare for prefetch
194
+ VLOG (3 ) << " prefetch block id is " << prefetch_block->ID ();
195
+ auto prefetch_prepared = executor.Prepare (*program, prefetch_block->ID ());
196
+ rpc_service_->SetPrefetchPreparedCtx (prefetch_prepared.get ());
197
+ prefetch_prepared.release ();
198
+
199
+ // start the server listening after all member initialized.
200
+ server_thread_.reset (new std::thread (RunServer, rpc_service_));
201
+ VLOG (3 ) << " wait server thread to become ready..." ;
202
+ sleep (5 );
203
+ // Write to a file of server selected port for python use.
204
+ SavePort (rpc_service_);
205
+ RunSyncLoop (&executor, program, &recv_scope, prefetch_block);
206
+ }
207
+
212
208
class ListenAndServOpMaker : public framework ::OpProtoAndCheckerMaker {
213
209
public:
214
210
ListenAndServOpMaker (OpProto *proto, OpAttrChecker *op_checker)
0 commit comments