@@ -11,7 +11,6 @@ use std::fmt;
11
11
use std:: path:: Path ;
12
12
use std:: path:: PathBuf ;
13
13
use std:: pin:: Pin ;
14
- use std:: sync:: Arc ;
15
14
use std:: task:: Context as TaskContext ;
16
15
use std:: task:: Poll ;
17
16
use std:: time:: Duration ;
@@ -287,6 +286,10 @@ pub enum LogClientMessage {
287
286
pub trait LogSender : Send + Sync {
288
287
/// Send a log payload in bytes
289
288
fn send ( & mut self , target : OutputTarget , payload : Vec < u8 > ) -> anyhow:: Result < ( ) > ;
289
+
290
+ /// Flush the log channel, ensuring all messages are delivered
291
+ /// Returns when the flush message has been acknowledged
292
+ async fn flush ( & mut self ) -> anyhow:: Result < ( ) > ;
290
293
}
291
294
292
295
/// Represents the target output stream (stdout or stderr)
@@ -299,11 +302,10 @@ pub enum OutputTarget {
299
302
}
300
303
301
304
/// Write the log to a local unix channel so some actors can listen to it and stream the log back.
302
- #[ derive( Clone ) ]
303
305
pub struct LocalLogSender {
304
306
hostname : String ,
305
307
pid : u32 ,
306
- tx : Arc < ChannelTx < LogMessage > > ,
308
+ tx : ChannelTx < LogMessage > ,
307
309
status : Receiver < TxStatus > ,
308
310
}
309
311
@@ -319,30 +321,51 @@ impl LocalLogSender {
319
321
Ok ( Self {
320
322
hostname,
321
323
pid,
322
- tx : Arc :: new ( tx ) ,
324
+ tx,
323
325
status,
324
326
} )
325
327
}
326
328
}
327
329
330
+ #[ async_trait]
328
331
impl LogSender for LocalLogSender {
329
332
fn send ( & mut self , target : OutputTarget , payload : Vec < u8 > ) -> anyhow:: Result < ( ) > {
330
333
if TxStatus :: Active == * self . status . borrow ( ) {
334
+ // post does not guarantee the message to be delivered
331
335
self . tx . post ( LogMessage :: Log {
332
336
hostname : self . hostname . clone ( ) ,
333
337
pid : self . pid ,
334
338
output_target : target,
335
339
payload : Serialized :: serialize_anon ( & payload) ?,
336
340
} ) ;
337
341
} else {
338
- tracing:: trace !(
342
+ tracing:: debug !(
339
343
"log sender {} is not active, skip sending log" ,
340
344
self . tx. addr( )
341
345
)
342
346
}
343
347
344
348
Ok ( ( ) )
345
349
}
350
+
351
+ async fn flush ( & mut self ) -> anyhow:: Result < ( ) > {
352
+ // send will make sure message is delivered
353
+ if TxStatus :: Active == * self . status . borrow ( ) {
354
+ match self . tx . send ( LogMessage :: Flush { } ) . await {
355
+ Ok ( ( ) ) => Ok ( ( ) ) ,
356
+ Err ( e) => {
357
+ tracing:: error!( "log sender {} error sending flush message: {}" , self . pid, e) ;
358
+ Err ( anyhow:: anyhow!( "error sending flush message: {}" , e) )
359
+ }
360
+ }
361
+ } else {
362
+ tracing:: debug!(
363
+ "log sender {} is not active, skip sending flush message" ,
364
+ self . tx. addr( )
365
+ ) ;
366
+ Ok ( ( ) )
367
+ }
368
+ }
346
369
}
347
370
348
371
/// A custom writer that tees to both stdout/stderr.
@@ -412,13 +435,17 @@ pub fn create_log_writers(
412
435
) ,
413
436
anyhow:: Error ,
414
437
> {
415
- let log_sender = LocalLogSender :: new ( log_channel, pid) ?;
416
-
417
438
// Create LogWriter instances for stdout and stderr using the shared log sender
418
- let stdout_writer =
419
- LogWriter :: with_default_writer ( local_rank, OutputTarget :: Stdout , log_sender. clone ( ) ) ?;
420
- let stderr_writer =
421
- LogWriter :: with_default_writer ( local_rank, OutputTarget :: Stderr , log_sender) ?;
439
+ let stdout_writer = LogWriter :: with_default_writer (
440
+ local_rank,
441
+ OutputTarget :: Stdout ,
442
+ LocalLogSender :: new ( log_channel. clone ( ) , pid) ?,
443
+ ) ?;
444
+ let stderr_writer = LogWriter :: with_default_writer (
445
+ local_rank,
446
+ OutputTarget :: Stderr ,
447
+ LocalLogSender :: new ( log_channel, pid) ?,
448
+ ) ?;
422
449
423
450
Ok ( ( Box :: new ( stdout_writer) , Box :: new ( stderr_writer) ) )
424
451
}
@@ -495,7 +522,34 @@ impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static>
495
522
496
523
fn poll_flush ( self : Pin < & mut Self > , cx : & mut TaskContext < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
497
524
let this = self . get_mut ( ) ;
498
- Pin :: new ( & mut this. std_writer ) . poll_flush ( cx)
525
+
526
+ // First, flush the standard writer
527
+ match Pin :: new ( & mut this. std_writer ) . poll_flush ( cx) {
528
+ Poll :: Ready ( Ok ( ( ) ) ) => {
529
+ // Now send a Flush message to the other side of the channel.
530
+ let mut flush_future = this. log_sender . flush ( ) ;
531
+ match flush_future. as_mut ( ) . poll ( cx) {
532
+ Poll :: Ready ( Ok ( ( ) ) ) => {
533
+ // Successfully sent the flush message
534
+ Poll :: Ready ( Ok ( ( ) ) )
535
+ }
536
+ Poll :: Ready ( Err ( e) ) => {
537
+ // Error sending the flush message
538
+ tracing:: error!( "error sending flush message: {}" , e) ;
539
+ Poll :: Ready ( Err ( io:: Error :: other ( format ! (
540
+ "error sending flush message: {}" ,
541
+ e
542
+ ) ) ) )
543
+ }
544
+ Poll :: Pending => {
545
+ // The future is not ready yet, so we return Pending
546
+ // The waker is already registered by polling the future
547
+ Poll :: Pending
548
+ }
549
+ }
550
+ }
551
+ other => other, // Propagate any errors or Pending state from the std_writer flush
552
+ }
499
553
}
500
554
501
555
fn poll_shutdown (
@@ -964,14 +1018,19 @@ mod tests {
964
1018
// Mock implementation of LogSender for testing
965
1019
struct MockLogSender {
966
1020
log_sender : mpsc:: UnboundedSender < ( OutputTarget , String ) > , // (output_target, content)
1021
+ flush_called : Arc < Mutex < bool > > , // Track if flush was called
967
1022
}
968
1023
969
1024
impl MockLogSender {
970
1025
fn new ( log_sender : mpsc:: UnboundedSender < ( OutputTarget , String ) > ) -> Self {
971
- Self { log_sender }
1026
+ Self {
1027
+ log_sender,
1028
+ flush_called : Arc :: new ( Mutex :: new ( false ) ) ,
1029
+ }
972
1030
}
973
1031
}
974
1032
1033
+ #[ async_trait]
975
1034
impl LogSender for MockLogSender {
976
1035
fn send ( & mut self , output_target : OutputTarget , payload : Vec < u8 > ) -> anyhow:: Result < ( ) > {
977
1036
// For testing purposes, convert to string if it's valid UTF-8
@@ -984,6 +1043,16 @@ mod tests {
984
1043
. send ( ( output_target, line) )
985
1044
. map_err ( |e| anyhow:: anyhow!( "Failed to send log in test: {}" , e) )
986
1045
}
1046
+
1047
+ async fn flush ( & mut self ) -> anyhow:: Result < ( ) > {
1048
+ // Mark that flush was called
1049
+ let mut flush_called = self . flush_called . lock ( ) . unwrap ( ) ;
1050
+ * flush_called = true ;
1051
+
1052
+ // For testing purposes, just return Ok
1053
+ // In a real implementation, this would wait for all messages to be delivered
1054
+ Ok ( ( ) )
1055
+ }
987
1056
}
988
1057
989
1058
#[ tokio:: test]
@@ -1098,6 +1167,32 @@ mod tests {
1098
1167
// The rest of the content will be replacement characters, but we don't care about the exact representation
1099
1168
}
1100
1169
1170
+ #[ tokio:: test]
1171
+ async fn test_log_writer_poll_flush ( ) {
1172
+ // Create a channel to receive logs
1173
+ let ( log_sender, _log_receiver) = mpsc:: unbounded_channel ( ) ;
1174
+
1175
+ // Create a mock log sender that tracks flush calls
1176
+ let mock_log_sender = MockLogSender :: new ( log_sender) ;
1177
+ let log_sender_flush_tracker = mock_log_sender. flush_called . clone ( ) ;
1178
+
1179
+ // Create mock writers for stdout and stderr
1180
+ let ( stdout_mock_writer, _) = MockWriter :: new ( ) ;
1181
+ let stdout_writer: Box < dyn io:: AsyncWrite + Send + Unpin > = Box :: new ( stdout_mock_writer) ;
1182
+
1183
+ // Create a log writer with the mocks
1184
+ let mut writer = LogWriter :: new ( OutputTarget :: Stdout , stdout_writer, mock_log_sender) ;
1185
+
1186
+ // Call flush on the writer
1187
+ writer. flush ( ) . await . unwrap ( ) ;
1188
+
1189
+ // Verify that log sender's flush were called
1190
+ assert ! (
1191
+ * log_sender_flush_tracker. lock( ) . unwrap( ) ,
1192
+ "LogSender's flush was not called"
1193
+ ) ;
1194
+ }
1195
+
1101
1196
#[ test]
1102
1197
fn test_string_similarity ( ) {
1103
1198
// Test exact match
0 commit comments