@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include < stdio.h> // for removing the port file
16
+ #include < csignal>
17
+ #include < cstdlib>
16
18
#include < fstream>
17
- #include < ostream>
18
19
#include < thread> // NOLINT
19
20
#include < vector>
20
21
@@ -28,7 +29,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
28
29
service->RunSyncUpdate ();
29
30
VLOG (4 ) << " RunServer thread end" ;
30
31
}
31
-
32
32
static void split (const std::string &str, char sep,
33
33
std::vector<std::string> *pieces) {
34
34
pieces->clear ();
@@ -59,7 +59,7 @@ static void ParallelExecuteBlocks(
59
59
int run_block = idx; // thread local
60
60
try {
61
61
executor->RunPreparedContext (prepared[run_block].get (), scope);
62
- } catch (std::exception &e) {
62
+ } catch (const std::exception &e) {
63
63
LOG (ERROR) << " run sub program error " << e.what ();
64
64
}
65
65
}));
@@ -75,8 +75,11 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
75
75
const framework::AttributeMap &attrs)
76
76
: OperatorBase(type, inputs, outputs, attrs) {}
77
77
78
+ ListenAndServOp::~ListenAndServOp () { Stop (); }
79
+
78
80
void ListenAndServOp::Stop () {
79
81
rpc_service_->Push (LISTEN_TERMINATE_MESSAGE);
82
+ rpc_service_->ShutDown ();
80
83
server_thread_->join ();
81
84
auto file_path = string::Sprintf (" /tmp/paddle.%d.port" , ::getpid ());
82
85
remove (file_path.c_str ());
@@ -122,7 +125,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
122
125
// Record received sparse variables, so that
123
126
// we could reset those after execute optimize program
124
127
std::vector<framework::Variable *> sparse_vars;
125
- while (!exit_flag) {
128
+ while (!exit_flag && ! SignalHandler::IsProgramExit () ) {
126
129
// Get from multiple trainers, we don't care about the order in which
127
130
// the gradients arrives, just add suffix 0~n and merge the gradient.
128
131
rpc_service_->SetCond (0 );
@@ -187,7 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
187
190
// mini-batch.
188
191
// TODO(Yancey1989): move the reset action into an operator, we couldn't
189
192
// have any hide logic in the operator.
190
- for (auto & var : sparse_vars) {
193
+ for (framework::Variable * var : sparse_vars) {
191
194
var->GetMutable <framework::SelectedRows>()->mutable_rows ()->clear ();
192
195
}
193
196
@@ -204,8 +207,12 @@ static void AsyncUpdateThread(
204
207
framework::Executor *executor,
205
208
framework::ExecutorPrepareContext *prepared) {
206
209
VLOG (3 ) << " update thread for " << var_name << " started" ;
207
- while (!exit_flag) {
210
+ while (!exit_flag && ! SignalHandler::IsProgramExit () ) {
208
211
const detail::ReceivedMessage v = queue->Pop ();
212
+ if (SignalHandler::IsProgramExit ()) {
213
+ VLOG (3 ) << " update thread for " << var_name << " exit" ;
214
+ break ;
215
+ }
209
216
auto recv_var_name = v.first ;
210
217
VLOG (4 ) << " async update " << recv_var_name;
211
218
auto var = v.second ->GetVar ();
@@ -217,7 +224,7 @@ static void AsyncUpdateThread(
217
224
try {
218
225
executor->RunPreparedContext (prepared,
219
226
v.second ->GetMutableLocalScope ());
220
- } catch (std::exception &e) {
227
+ } catch (const std::exception &e) {
221
228
LOG (ERROR) << " run sub program error " << e.what ();
222
229
}
223
230
});
@@ -236,15 +243,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
236
243
237
244
auto grad_to_block_id_str =
238
245
Attr<std::vector<std::string>>(" grad_to_block_id" );
239
- for (auto &grad_and_id : grad_to_block_id_str) {
246
+ for (const auto &grad_and_id : grad_to_block_id_str) {
240
247
std::vector<std::string> pieces;
241
248
split (grad_and_id, ' :' , &pieces);
242
249
VLOG (3 ) << " after split, grad = " << pieces[0 ] << " , id=" << pieces[1 ];
243
250
PADDLE_ENFORCE_EQ (pieces.size (), 2 );
244
251
PADDLE_ENFORCE_EQ (grad_to_block_id.count (pieces[0 ]), 0 );
245
252
int block_id = std::stoi (pieces[1 ]);
246
253
grad_to_block_id[pieces[0 ]] = block_id;
247
- grad_to_queue[pieces[0 ]] = std::make_shared<detail::ReceivedQueue>();
254
+ std::shared_ptr<detail::ReceivedQueue> queue =
255
+ std::make_shared<detail::ReceivedQueue>();
256
+ grad_to_queue[pieces[0 ]] = queue;
257
+ // record blocking queue in SignalHandler
258
+ SignalHandler::RegisterBlockingQueue (queue);
248
259
id_to_grad[block_id] = pieces[0 ];
249
260
}
250
261
size_t num_blocks = program->Size ();
@@ -276,9 +287,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
276
287
executor, grad_to_prepared_ctx[grad_name].get ());
277
288
}));
278
289
}
279
-
280
290
VLOG (3 ) << " RunAsyncLoop into while" ;
281
- while (!exit_flag) {
291
+ while (!exit_flag && ! SignalHandler::IsProgramExit () ) {
282
292
const detail::ReceivedMessage v = rpc_service_->Get ();
283
293
auto recv_var_name = v.first ;
284
294
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
@@ -333,6 +343,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
333
343
VLOG (3 ) << " wait server thread to become ready..." ;
334
344
rpc_service_->WaitServerReady ();
335
345
346
+ // register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
347
+ signal (SIGINT, SignalHandler::StopAndExit);
348
+ signal (SIGTERM, SignalHandler::StopAndExit);
349
+
336
350
// Write to a file of server selected port for python use.
337
351
std::string file_path = string::Sprintf (" /tmp/paddle.%d.selected_port" ,
338
352
static_cast <int >(::getpid ()));
@@ -348,12 +362,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
348
362
public:
349
363
void Make () {
350
364
AddInput (" X" , " (Tensor) Variables that server recv." ).AsDuplicable ();
351
- AddComment (R"DOC(
352
- ListenAndServ operator
353
-
354
- This operator will start a RPC server which can receive variables
355
- from send_op and send back variables to recv_op.
356
- )DOC" );
365
+ AddComment (R"DOC( " + "ListenAndServ operator" + "\n" + "This operator" +
366
+ " will start a RPC server which can receive variables from send_op and send" +
367
+ "back variables to recv_op.)DOC" );
357
368
AddAttr<std::string>(" endpoint" ,
358
369
" (string, default 127.0.0.1:6164)"
359
370
" IP address to listen on." )
@@ -374,6 +385,29 @@ from send_op and send back variables to recv_op.
374
385
}
375
386
};
376
387
388
+ bool SignalHandler::program_exit_flag_ = false ;
389
+
390
+ SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
391
+
392
+ void SignalHandler::StopAndExit (int signal_num) {
393
+ VLOG (3 ) << " Catch interrupt signal: " << signal_num << " , program will exit" ;
394
+
395
+ program_exit_flag_ = true ;
396
+
397
+ // awake all blocking queues
398
+ for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin ();
399
+ iter != blocking_queue_set_.end (); iter++) {
400
+ iter->get ()->Push (
401
+ std::make_pair (std::string (LISTEN_TERMINATE_MESSAGE), nullptr ));
402
+ }
403
+
404
+ exit (EXIT_SUCCESS);
405
+ }
406
+
407
+ void SignalHandler::RegisterBlockingQueue (BlockingQueue &queue) {
408
+ blocking_queue_set_.insert (queue);
409
+ }
410
+
377
411
} // namespace operators
378
412
} // namespace paddle
379
413
0 commit comments