1
1
use clap:: Parser ;
2
+ use nix:: sys:: signal:: { self , Signal } ;
3
+ use nix:: unistd:: Pid ;
2
4
use std:: env;
3
5
use std:: io:: { BufRead , BufReader , ErrorKind , Write } ;
4
6
use std:: path:: Path ;
5
- use std:: process:: ExitCode ;
7
+ use std:: process:: { Command , ExitCode , Stdio } ;
6
8
use std:: sync:: atomic:: { AtomicBool , Ordering } ;
7
9
use std:: sync:: mpsc:: TryRecvError ;
8
10
use std:: sync:: Arc ;
@@ -13,7 +15,7 @@ use std::time::{Duration, Instant};
13
15
use std:: { fs, io} ;
14
16
use std:: env:: VarError ;
15
17
use std:: ffi:: OsString ;
16
- use subprocess :: { Popen , PopenConfig , PopenError , Redirection } ;
18
+ use std :: os :: unix :: process :: CommandExt ;
17
19
use tracing:: info;
18
20
19
21
// In most cases this gives the best performance for inferencing
@@ -191,8 +193,7 @@ fn main() -> ExitCode {
191
193
Err ( TryRecvError :: Empty ) => {
192
194
sleep ( Duration :: from_millis ( 100 ) ) ;
193
195
}
194
- Ok ( ShardStatus :: Failed ( ( rank, err) ) ) => {
195
- tracing:: error!( "Shard {rank} failed to start:\n {err}" ) ;
196
+ Ok ( ShardStatus :: Failed ) => {
196
197
shutdown_shards ( shutdown, shutdown_receiver) ;
197
198
return ExitCode :: FAILURE ;
198
199
}
@@ -214,7 +215,6 @@ fn main() -> ExitCode {
214
215
// Start webserver
215
216
info ! ( "Starting Router" ) ;
216
217
let mut argv = vec ! [
217
- "text-generation-router" . to_string( ) ,
218
218
"--max-concurrent-requests" . to_string( ) ,
219
219
args. max_concurrent_requests. to_string( ) ,
220
220
"--max-sequence-length" . to_string( ) ,
@@ -271,27 +271,20 @@ fn main() -> ExitCode {
271
271
argv. push ( "--default-include-stop-seqs" . into ( ) ) ;
272
272
}
273
273
274
- let mut webserver = match Popen :: create (
275
- & argv,
276
- PopenConfig {
277
- stdout : Redirection :: Pipe ,
278
- stderr : Redirection :: Pipe ,
279
- // Needed for the shutdown procedure
280
- setpgid : true ,
281
- // env: Some(vec![("RUST_BACKTRACE".into(), "1".into())]),
282
- ..Default :: default ( )
283
- } ,
284
- ) {
274
+ let mut webserver = match Command :: new ( "text-generation-router" )
275
+ . args ( argv)
276
+ . stdout ( Stdio :: piped ( ) )
277
+ . stderr ( Stdio :: piped ( ) )
278
+ . process_group ( 0 )
279
+ . spawn ( )
280
+ {
285
281
Ok ( p) => p,
286
282
Err ( err) => {
287
- tracing:: error!( "Failed to start webserver: {err}" ) ;
288
- if let PopenError :: IoError ( err) = err {
289
- if err. kind ( ) == io:: ErrorKind :: NotFound {
290
- tracing:: error!( "text-generation-router not found in PATH" ) ;
291
- tracing:: error!( "Please install it with `make install-router`" )
292
- }
283
+ if err. kind ( ) == ErrorKind :: NotFound {
284
+ tracing:: error!( "text-generation-router not found in PATH" ) ;
285
+ tracing:: error!( "Please install it with `make install-router`" )
293
286
} else {
294
- tracing:: error!( "{err}" ) ;
287
+ tracing:: error!( "Failed to start webserver: {err}" ) ;
295
288
}
296
289
297
290
shutdown_shards ( shutdown, & shutdown_receiver) ;
@@ -318,28 +311,25 @@ fn main() -> ExitCode {
318
311
let mut exit_code = ExitCode :: SUCCESS ;
319
312
320
313
while running. load ( Ordering :: SeqCst ) {
321
- if let Ok ( ShardStatus :: Failed ( ( rank, err) ) ) = status_receiver. try_recv ( ) {
322
- tracing:: error!( "Shard {rank} failed: {err}" ) ;
314
+ if let Ok ( ShardStatus :: Failed ) = status_receiver. try_recv ( ) {
323
315
exit_code = ExitCode :: FAILURE ;
324
316
break ;
325
317
} ;
326
318
327
- match webserver. poll ( ) {
319
+ match webserver. try_wait ( ) . expect ( "Error polling status of router process" ) {
328
320
Some ( _) => {
329
321
tracing:: error!( "Webserver Crashed" ) ;
330
322
shutdown_shards ( shutdown, & shutdown_receiver) ;
331
323
return ExitCode :: FAILURE ;
332
- }
333
- None => {
334
- sleep ( Duration :: from_millis ( 100 ) ) ;
335
- }
324
+ } ,
325
+ None => sleep ( Duration :: from_millis ( 100 ) ) ,
336
326
} ;
337
327
}
338
328
339
329
// Graceful termination
340
- webserver. terminate ( ) . unwrap ( ) ;
330
+ signal :: kill ( Pid :: from_raw ( webserver. id ( ) as i32 ) , Signal :: SIGTERM ) . unwrap ( ) ;
341
331
info ! ( "Waiting for router to gracefully shutdown" ) ;
342
- webserver. wait_timeout ( Duration :: from_secs ( 120 ) ) . unwrap ( ) ;
332
+ webserver. wait ( ) . unwrap ( ) ;
343
333
info ! ( "Router terminated" ) ;
344
334
shutdown_shards ( shutdown, & shutdown_receiver) ;
345
335
@@ -392,7 +382,7 @@ fn find_num_shards(num_shard: Option<usize>) -> usize {
392
382
#[ derive( Debug ) ]
393
383
enum ShardStatus {
394
384
Ready ,
395
- Failed ( ( usize , String ) ) ,
385
+ Failed ,
396
386
}
397
387
398
388
#[ allow( clippy:: too_many_arguments) ]
@@ -421,11 +411,12 @@ fn shard_manager(
421
411
let uds_string = format ! ( "{uds_path}-{rank}" ) ;
422
412
let uds = Path :: new ( & uds_string) ;
423
413
// Clean previous runs
424
- fs:: remove_file ( uds) . unwrap_or_default ( ) ;
414
+ if uds. exists ( ) {
415
+ fs:: remove_file ( uds) . unwrap_or_default ( ) ;
416
+ }
425
417
426
418
// Process args
427
419
let mut shard_argv = vec ! [
428
- "text-generation-server" . to_string( ) ,
429
420
"serve" . to_string( ) ,
430
421
model_name,
431
422
deployment_framework,
@@ -517,30 +508,24 @@ fn shard_manager(
517
508
518
509
// Start process
519
510
info ! ( "Starting shard {rank}" ) ;
520
- let mut p = match Popen :: create (
521
- & shard_argv,
522
- PopenConfig {
523
- stdout : Redirection :: Pipe ,
524
- stderr : Redirection :: Pipe ,
525
- // Needed for the shutdown procedure
526
- setpgid : true ,
527
- // NCCL env vars
528
- env : Some ( env) ,
529
- ..Default :: default ( )
530
- } ,
531
- ) {
511
+ let mut p = match Command :: new ( "text-generation-server" )
512
+ . args ( shard_argv)
513
+ . envs ( env)
514
+ . stdout ( Stdio :: piped ( ) )
515
+ . stderr ( Stdio :: piped ( ) )
516
+ . process_group ( 0 )
517
+ . spawn ( )
518
+ {
532
519
Ok ( p) => p,
533
520
Err ( err) => {
534
- if let PopenError :: IoError ( ref err) = err {
535
- if err . kind ( ) == io :: ErrorKind :: NotFound {
536
- tracing:: error!( "text-generation-server not found in PATH" ) ;
537
- tracing :: error! ( "Please install it with `make install-server`" )
538
- }
521
+ if err. kind ( ) == ErrorKind :: NotFound {
522
+ tracing :: error! ( "text-generation-server not found in PATH" ) ;
523
+ tracing:: error!( "Please install it with `make install-server`" )
524
+ } else {
525
+ tracing :: error! ( "Shard {rank} failed to start: \n {err}" ) ;
539
526
}
540
- status_sender
541
- . send ( ShardStatus :: Failed ( ( rank, err. to_string ( ) ) ) )
542
- . unwrap ( ) ;
543
- return ;
527
+ status_sender. send ( ShardStatus :: Failed ) . unwrap ( ) ;
528
+ return
544
529
}
545
530
} ;
546
531
@@ -563,20 +548,26 @@ fn shard_manager(
563
548
let mut wait_time = Instant :: now ( ) ;
564
549
loop {
565
550
// Process exited
566
- if let Some ( status) = p. poll ( ) {
567
- // Ensure we finish propagating any final stdout/stderr from the shard
568
- stdout_thread. join ( ) . unwrap_or_default ( ) ;
569
- io:: stdout ( ) . flush ( ) . unwrap_or_default ( ) ;
570
- stderr_thread. join ( ) . unwrap_or_default ( ) ;
571
- io:: stderr ( ) . flush ( ) . unwrap_or_default ( ) ;
572
- status_sender. send ( ShardStatus :: Failed ( ( rank, format ! ( "{status:?}" ) ) ) ) . unwrap ( ) ;
551
+ if let Some ( status) = p. try_wait ( )
552
+ . expect ( & * format ! ( "Error polling status of shard {rank}" ) ) {
553
+ if * shutdown. lock ( ) . unwrap ( ) {
554
+ info ! ( "Shard {rank} terminated" ) ;
555
+ } else {
556
+ tracing:: error!( "Shard {rank} failed: {status:?}" ) ;
557
+ // Ensure we finish propagating any final stdout/stderr from the shard
558
+ stdout_thread. join ( ) . unwrap_or_default ( ) ;
559
+ io:: stdout ( ) . flush ( ) . unwrap_or_default ( ) ;
560
+ stderr_thread. join ( ) . unwrap_or_default ( ) ;
561
+ io:: stderr ( ) . flush ( ) . unwrap_or_default ( ) ;
562
+ status_sender. send ( ShardStatus :: Failed ) . unwrap ( ) ;
563
+ }
573
564
return
574
565
}
575
566
576
567
// We received a shutdown signal
577
568
if * shutdown. lock ( ) . unwrap ( ) {
578
- p. terminate ( ) . unwrap ( ) ;
579
- let _ = p. wait_timeout ( Duration :: from_secs ( 90 ) ) ;
569
+ p. kill ( ) . unwrap ( ) ;
570
+ let _ = p. wait ( ) . unwrap ( ) ;
580
571
info ! ( "Shard {rank} terminated" ) ;
581
572
return
582
573
}
0 commit comments