@@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include < stdint.h>
16
15
#include < ostream>
16
+ #include < thread>
17
17
18
- #include " paddle/fluid/framework/executor.h"
19
- #include " paddle/fluid/framework/lod_tensor.h"
20
- #include " paddle/fluid/framework/op_registry.h"
21
- #include " paddle/fluid/framework/threadpool.h"
22
- #include " paddle/fluid/operators/detail/grpc_server.h"
18
+ #include " paddle/fluid/operators/listen_and_serv_op.h"
23
19
24
20
namespace paddle {
25
21
namespace operators {
26
22
27
- constexpr char kOptimizeBlock [] = " OptimizeBlock" ;
28
-
29
23
void RunServer (std::shared_ptr<detail::AsyncGRPCServer> service) {
30
24
service->RunSyncUpdate ();
31
25
VLOG (4 ) << " RunServer thread end" ;
@@ -66,143 +60,138 @@ static void ParallelExecuteBlocks(
66
60
for (size_t i = 0 ; i < fs.size (); ++i) fs[i].wait ();
67
61
}
68
62
69
- class ListenAndServOp : public framework ::OperatorBase {
70
- public:
71
- ListenAndServOp (const std::string &type,
72
- const framework::VariableNameMap &inputs,
73
- const framework::VariableNameMap &outputs,
74
- const framework::AttributeMap &attrs)
75
- : OperatorBase(type, inputs, outputs, attrs) {
76
- if (!rpc_service_) {
77
- std::string endpoint = Attr<std::string>(" endpoint" );
78
- rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
79
- server_thread_.reset (new std::thread (RunServer, rpc_service_));
80
- }
81
- }
63
+ ListenAndServOp::ListenAndServOp (const std::string &type,
64
+ const framework::VariableNameMap &inputs,
65
+ const framework::VariableNameMap &outputs,
66
+ const framework::AttributeMap &attrs)
67
+ : OperatorBase(type, inputs, outputs, attrs) {}
82
68
83
- void Stop () override {
84
- rpc_service_->Push (LISTEN_TERMINATE_MESSAGE);
85
- server_thread_->join ();
69
+ int ListenAndServOp::GetSelectedPort () {
70
+ return rpc_service_->GetSelectedPort ();
71
+ }
72
+
73
+ void ListenAndServOp::Stop () {
74
+ rpc_service_->Push (LISTEN_TERMINATE_MESSAGE);
75
+ server_thread_->join ();
76
+ }
77
+
78
+ void ListenAndServOp::RunImpl (const framework::Scope &scope,
79
+ const platform::Place &dev_place) const {
80
+ platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
81
+ auto &dev_ctx = *pool.Get (dev_place);
82
+ framework::Scope &recv_scope = scope.NewScope ();
83
+
84
+ if (!rpc_service_) {
85
+ std::string endpoint = Attr<std::string>(" endpoint" );
86
+ rpc_service_.reset (new detail::AsyncGRPCServer (endpoint));
86
87
}
87
88
88
- void RunImpl (const framework::Scope &scope,
89
- const platform::Place &dev_place) const override {
90
- platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance ();
91
- auto &dev_ctx = *pool.Get (dev_place);
92
- framework::Scope &recv_scope = scope.NewScope ();
93
-
94
- // FIXME(Yancey1989): initialize rpc server with lazy mode.
95
- rpc_service_->SetScope (&recv_scope);
96
- rpc_service_->SetDevCtx (&dev_ctx);
97
- auto ins = Inputs (" X" );
98
- auto fan_in = Attr<int >(" Fanin" );
99
-
100
- auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock );
101
- auto *program = block->Program ();
102
- size_t num_blocks = program->Size ();
103
- PADDLE_ENFORCE_GE (num_blocks, 2 ,
104
- " server program should have at least 2 blocks" );
105
-
106
- framework::Executor executor (dev_place);
107
- std::vector<int > block_list;
108
- for (size_t blkid = 1 ; blkid < num_blocks; ++blkid)
109
- block_list.push_back (blkid);
110
- auto prepared = executor.Prepare (*program, block_list);
111
- prepared.insert (
112
- prepared.begin (),
113
- std::shared_ptr<framework::ExecutorPrepareContext>(nullptr ));
114
-
115
- // TODO(qiao) set proper fields for table lookup and update
116
- rpc_service_->SetExecutor (&executor);
117
- rpc_service_->SetPrefetchBlkdId (0 );
118
- rpc_service_->SetProgram (program);
119
-
120
- // TODO(typhoonzero): change this to a while_op for every cluster-batch.
121
- bool exit_flag = false ;
122
- // Record received sparse variables, so that
123
- // we could reset those after execute optimize program
124
- std::vector<framework::Variable *> sparse_vars;
125
- while (!exit_flag) {
126
- // Get from multiple trainers, we don't care about the order in which
127
- // the gradients arrives, just add suffix 0~n and merge the gradient.
128
- rpc_service_->SetCond (0 );
129
- size_t recv_var_cnt = 0 ;
130
- int batch_barrier = 0 ;
131
- while (batch_barrier != fan_in) {
132
- const detail::ReceivedMessage v = rpc_service_->Get ();
133
- auto recv_var_name = v.first ;
134
- if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
135
- LOG (INFO) << " received terminate message and exit" ;
136
- exit_flag = true ;
137
- break ;
138
- } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
139
- VLOG (3 ) << " recv batch barrier message" ;
140
- batch_barrier++;
141
- continue ;
142
- } else {
143
- VLOG (3 ) << " received grad: " << recv_var_name;
144
- recv_var_cnt++;
145
- auto var = v.second ->GetVar ();
146
- if (var == nullptr ) {
147
- LOG (ERROR) << " Can not find server side var: " << recv_var_name;
148
- PADDLE_THROW (" Can not find server side var" );
149
- }
150
- if (var->IsType <framework::SelectedRows>()) {
151
- sparse_vars.push_back (var);
152
- }
153
- }
154
- }
155
- if (exit_flag) {
156
- rpc_service_->SetCond (1 );
157
- rpc_service_->ShutDown ();
89
+ auto ins = Inputs (" X" );
90
+ auto fan_in = Attr<int >(" Fanin" );
91
+ auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock );
92
+ auto *program = block->Program ();
93
+ size_t num_blocks = program->Size ();
94
+ PADDLE_ENFORCE_GE (num_blocks, 2 ,
95
+ " server program should have at least 2 blocks" );
96
+
97
+ framework::Executor executor (dev_place);
98
+ std::vector<int > block_list;
99
+ for (size_t blkid = 1 ; blkid < num_blocks; ++blkid) {
100
+ block_list.push_back (blkid);
101
+ }
102
+ auto prepared = executor.Prepare (*program, block_list);
103
+ // Insert placeholder for block0 which holds current op itself.
104
+ prepared.insert (prepared.begin (),
105
+ std::shared_ptr<framework::ExecutorPrepareContext>(nullptr ));
106
+
107
+ rpc_service_->SetScope (&recv_scope);
108
+ rpc_service_->SetDevCtx (&dev_ctx);
109
+ // TODO(qiao) set proper fields for table lookup and update
110
+ rpc_service_->SetExecutor (&executor);
111
+ rpc_service_->SetPrefetchBlkdId (0 );
112
+ rpc_service_->SetProgram (program);
113
+ // start the server listening after all member initialized.
114
+ server_thread_.reset (new std::thread (RunServer, rpc_service_));
115
+ // FIXME(typhoonzero): do we need to wait until the server port is ready?
116
+ sleep (5 );
117
+
118
+ // TODO(typhoonzero): change this to a while_op for every cluster-batch.
119
+ bool exit_flag = false ;
120
+ // Record received sparse variables, so that
121
+ // we could reset those after execute optimize program
122
+ std::vector<framework::Variable *> sparse_vars;
123
+ while (!exit_flag) {
124
+ // Get from multiple trainers, we don't care about the order in which
125
+ // the gradients arrives, just add suffix 0~n and merge the gradient.
126
+ rpc_service_->SetCond (0 );
127
+ size_t recv_var_cnt = 0 ;
128
+ int batch_barrier = 0 ;
129
+ while (batch_barrier != fan_in) {
130
+ const detail::ReceivedMessage v = rpc_service_->Get ();
131
+ auto recv_var_name = v.first ;
132
+ if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
133
+ LOG (INFO) << " received terminate message and exit" ;
134
+ exit_flag = true ;
158
135
break ;
159
- }
160
-
161
- // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
162
- // and this will still work.
163
-
164
- // The optimize blocks which have the same parent ID would run parallel
165
- // TODO(Yancey1989): need to use ParallelExecutor for future
166
- int32_t last_parent_blkid = program->Block (1 ).Parent ();
167
- std::vector<size_t > parallel_blkids;
168
- parallel_blkids.push_back (1 );
169
- double ts = detail::GetTimestamp ();
170
- for (size_t blkid = 2 ; blkid < num_blocks; ++blkid) {
171
- if (program->Block (blkid).Parent () != last_parent_blkid) {
172
- for (size_t idx : parallel_blkids) VLOG (3 ) << idx;
173
- ParallelExecuteBlocks (parallel_blkids, &executor, prepared, program,
174
- &recv_scope);
175
- parallel_blkids.clear ();
176
- last_parent_blkid = program->Block (blkid).Parent ();
136
+ } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
137
+ VLOG (3 ) << " recv batch barrier message" ;
138
+ batch_barrier++;
139
+ continue ;
140
+ } else {
141
+ VLOG (3 ) << " received grad: " << recv_var_name;
142
+ recv_var_cnt++;
143
+ auto var = v.second ->GetVar ();
144
+ if (var == nullptr ) {
145
+ LOG (ERROR) << " Can not find server side var: " << recv_var_name;
146
+ PADDLE_THROW (" Can not find server side var" );
147
+ }
148
+ if (var->IsType <framework::SelectedRows>()) {
149
+ sparse_vars.push_back (var);
177
150
}
178
- parallel_blkids.push_back (blkid);
179
- }
180
- ParallelExecuteBlocks (parallel_blkids, &executor, prepared, program,
181
- &recv_scope);
182
-
183
- VLOG (3 ) << " run all blocks spent " << detail::GetTimestamp () - ts
184
- << " (ms)" ;
185
-
186
- // Reset the received sparse variables, the sum operator would not
187
- // sum the input sparse variables which rows is empty at the next
188
- // mini-batch.
189
- // TODO(Yancey1989): move the reset action into an operator, we couldn't
190
- // have any hide logic in the operator.
191
- for (auto &var : sparse_vars) {
192
- var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
193
151
}
152
+ }
153
+ if (exit_flag) {
194
154
rpc_service_->SetCond (1 );
195
- // NOTE: does not consider barrier request retry in here, we may use
196
- // global barrier id to resolve this.
197
- rpc_service_->WaitClientGet (fan_in);
198
- sparse_vars.clear ();
199
- } // while(true)
200
- }
155
+ rpc_service_->ShutDown ();
156
+ break ;
157
+ }
201
158
202
- protected:
203
- std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
204
- std::shared_ptr<std::thread> server_thread_;
205
- };
159
+ // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
160
+ // and this will still work.
161
+
162
+ // The optimize blocks which have the same parent ID would run parallel
163
+ // TODO(Yancey1989): need to use ParallelExecutor for future
164
+ int32_t last_parent_blkid = program->Block (1 ).Parent ();
165
+ std::vector<size_t > parallel_blkids;
166
+ parallel_blkids.push_back (1 );
167
+ double ts = detail::GetTimestamp ();
168
+ for (size_t blkid = 2 ; blkid < num_blocks; ++blkid) {
169
+ if (program->Block (blkid).Parent () != last_parent_blkid) {
170
+ ParallelExecuteBlocks (parallel_blkids, &executor, prepared, program,
171
+ &recv_scope);
172
+ parallel_blkids.clear ();
173
+ last_parent_blkid = program->Block (blkid).Parent ();
174
+ }
175
+ parallel_blkids.push_back (blkid);
176
+ }
177
+ ParallelExecuteBlocks (parallel_blkids, &executor, prepared, program,
178
+ &recv_scope);
179
+ VLOG (2 ) << " run all blocks spent " << detail::GetTimestamp () - ts << " (ms)" ;
180
+
181
+ // Reset the received sparse variables, the sum operator would not
182
+ // sum the input sparse variables which rows is empty at the next
183
+ // mini-batch.
184
+ // TODO(Yancey1989): move the reset action into an operator, we couldn't
185
+ // have any hide logic in the operator.
186
+ for (auto &var : sparse_vars) {
187
+ var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
188
+ }
189
+ rpc_service_->SetCond (1 );
190
+ // FIXME(typhoonzero): use another condition to sync wait clients get.
191
+ rpc_service_->WaitClientGet (fan_in);
192
+ sparse_vars.clear ();
193
+ } // while(true)
194
+ }
206
195
207
196
class ListenAndServOpMaker : public framework ::OpProtoAndCheckerMaker {
208
197
public:
0 commit comments