@@ -135,9 +135,10 @@ struct Child {
135
135
channel : ChannelState ,
136
136
group : monitor:: Group ,
137
137
exit_flag : Option < flag:: Flag > ,
138
- stdout : LogTailer ,
139
- stderr : LogTailer ,
138
+ stdout : Option < LogTailer > ,
139
+ stderr : Option < LogTailer > ,
140
140
stop_reason : Arc < OnceLock < ProcStopReason > > ,
141
+ process_pid : Arc < std:: sync:: Mutex < Option < i32 > > > ,
141
142
}
142
143
143
144
impl Child {
@@ -179,27 +180,23 @@ impl Child {
179
180
stderr_tee,
180
181
) ;
181
182
183
+ let process_pid = Arc :: new ( std:: sync:: Mutex :: new ( Some ( process. id ( ) . unwrap ( ) as i32 ) ) ) ;
184
+
182
185
let child = Self {
183
186
local_rank,
184
187
channel : ChannelState :: NotConnected ,
185
188
group,
186
189
exit_flag : Some ( exit_flag) ,
187
- stdout,
188
- stderr,
190
+ stdout : Some ( stdout ) ,
191
+ stderr : Some ( stderr ) ,
189
192
stop_reason : Arc :: clone ( & stop_reason) ,
193
+ process_pid : process_pid. clone ( ) ,
190
194
} ;
191
195
192
196
let monitor = async move {
193
197
let reason = tokio:: select! {
194
198
_ = handle => {
195
- let Some ( id) = process. id( ) else {
196
- tracing:: error!( "could not get child process id" ) ;
197
- return ProcStopReason :: Unknown ;
198
- } ;
199
- if let Err ( e) = signal:: kill( Pid :: from_raw( id as i32 ) , signal:: SIGTERM ) {
200
- tracing:: error!( "failed to kill child process: {}" , e) ;
201
- return ProcStopReason :: Unknown ;
202
- } ;
199
+ Self :: ensure_killed( process_pid) ;
203
200
Self :: exit_status_to_reason( process. wait( ) . await )
204
201
}
205
202
result = process. wait( ) => {
@@ -214,6 +211,25 @@ impl Child {
214
211
( child, monitor)
215
212
}
216
213
214
+ fn ensure_killed ( pid : Arc < std:: sync:: Mutex < Option < i32 > > > ) {
215
+ match pid. lock ( ) . unwrap ( ) . take ( ) {
216
+ Some ( pid) => {
217
+ if let Err ( e) = signal:: kill ( Pid :: from_raw ( pid) , signal:: SIGTERM ) {
218
+ match e {
219
+ nix:: errno:: Errno :: ESRCH => {
220
+ // Process already gone.
221
+ tracing:: debug!( "pid {} already exited" , pid) ;
222
+ }
223
+ _ => {
224
+ tracing:: error!( "failed to kill {}: {}" , pid, e) ;
225
+ }
226
+ }
227
+ }
228
+ }
229
+ None => ( ) ,
230
+ }
231
+ }
232
+
217
233
fn exit_status_to_reason ( result : io:: Result < ExitStatus > ) -> ProcStopReason {
218
234
match result {
219
235
Ok ( status) if status. success ( ) => ProcStopReason :: Stopped ,
@@ -300,6 +316,12 @@ impl Child {
300
316
}
301
317
}
302
318
319
+ impl Drop for Child {
320
+ fn drop ( & mut self ) {
321
+ Self :: ensure_killed ( self . process_pid . clone ( ) ) ;
322
+ }
323
+ }
324
+
303
325
impl ProcessAlloc {
304
326
// Also implement exit (for graceful exit)
305
327
@@ -371,7 +393,6 @@ impl ProcessAlloc {
371
393
cmd. env ( bootstrap:: BOOTSTRAP_LOG_CHANNEL , log_channel. to_string ( ) ) ;
372
394
cmd. stdout ( Stdio :: piped ( ) ) ;
373
395
cmd. stderr ( Stdio :: piped ( ) ) ;
374
- cmd. kill_on_drop ( true ) ;
375
396
376
397
let proc_id = ProcId :: Ranked ( WorldId ( self . name . to_string ( ) ) , index) ;
377
398
tracing:: debug!( "Spawning process {:?}" , cmd) ;
@@ -472,7 +493,9 @@ impl Alloc for ProcessAlloc {
472
493
} ,
473
494
474
495
Some ( Ok ( ( index, mut reason) ) ) = self . children. join_next( ) => {
475
- let stderr_content = if let Some ( Child { stdout, stderr, ..} ) = self . remove( index) {
496
+ let stderr_content = if let Some ( mut child) = self . remove( index) {
497
+ let stdout = child. stdout. take( ) . unwrap( ) ;
498
+ let stderr = child. stderr. take( ) . unwrap( ) ;
476
499
stdout. abort( ) ;
477
500
stderr. abort( ) ;
478
501
let ( _stdout, _) = stdout. join( ) . await ;
0 commit comments