Skip to content

Commit 707f705

Browse files
James Sun and facebook-github-bot
authored and committed
poll_flush for log sender (#815)
Summary: Pull Request resolved: #815. This helps logs get delivered sooner, because poll_flush now guarantees that the log lines are sent to the other side of the channel. Of course, it is still not guaranteed that the client will receive all of the log lines if the process exits quickly. Reviewed By: vidhyav. Differential Revision: D79978241. fbshipit-source-id: a94b771eedae7a17af007652526945cd1f0a0b05
1 parent 6213478 commit 707f705

File tree

2 files changed

+109
-15
lines changed

2 files changed

+109
-15
lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 108 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use std::fmt;
1111
use std::path::Path;
1212
use std::path::PathBuf;
1313
use std::pin::Pin;
14-
use std::sync::Arc;
1514
use std::task::Context as TaskContext;
1615
use std::task::Poll;
1716
use std::time::Duration;
@@ -287,6 +286,10 @@ pub enum LogClientMessage {
287286
pub trait LogSender: Send + Sync {
288287
/// Send a log payload in bytes
289288
fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()>;
289+
290+
/// Flush the log channel, ensuring all messages are delivered
291+
/// Returns when the flush message has been acknowledged
292+
async fn flush(&mut self) -> anyhow::Result<()>;
290293
}
291294

292295
/// Represents the target output stream (stdout or stderr)
@@ -299,11 +302,10 @@ pub enum OutputTarget {
299302
}
300303

301304
/// Write the log to a local unix channel so some actors can listen to it and stream the log back.
302-
#[derive(Clone)]
303305
pub struct LocalLogSender {
304306
hostname: String,
305307
pid: u32,
306-
tx: Arc<ChannelTx<LogMessage>>,
308+
tx: ChannelTx<LogMessage>,
307309
status: Receiver<TxStatus>,
308310
}
309311

@@ -319,30 +321,51 @@ impl LocalLogSender {
319321
Ok(Self {
320322
hostname,
321323
pid,
322-
tx: Arc::new(tx),
324+
tx,
323325
status,
324326
})
325327
}
326328
}
327329

330+
#[async_trait]
328331
impl LogSender for LocalLogSender {
329332
fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
330333
if TxStatus::Active == *self.status.borrow() {
334+
// post does not guarantee that the message is delivered
331335
self.tx.post(LogMessage::Log {
332336
hostname: self.hostname.clone(),
333337
pid: self.pid,
334338
output_target: target,
335339
payload: Serialized::serialize_anon(&payload)?,
336340
});
337341
} else {
338-
tracing::trace!(
342+
tracing::debug!(
339343
"log sender {} is not active, skip sending log",
340344
self.tx.addr()
341345
)
342346
}
343347

344348
Ok(())
345349
}
350+
351+
async fn flush(&mut self) -> anyhow::Result<()> {
352+
// send will make sure message is delivered
353+
if TxStatus::Active == *self.status.borrow() {
354+
match self.tx.send(LogMessage::Flush {}).await {
355+
Ok(()) => Ok(()),
356+
Err(e) => {
357+
tracing::error!("log sender {} error sending flush message: {}", self.pid, e);
358+
Err(anyhow::anyhow!("error sending flush message: {}", e))
359+
}
360+
}
361+
} else {
362+
tracing::debug!(
363+
"log sender {} is not active, skip sending flush message",
364+
self.tx.addr()
365+
);
366+
Ok(())
367+
}
368+
}
346369
}
347370

348371
/// A custom writer that tees to both stdout/stderr.
@@ -412,13 +435,17 @@ pub fn create_log_writers(
412435
),
413436
anyhow::Error,
414437
> {
415-
let log_sender = LocalLogSender::new(log_channel, pid)?;
416-
417438
// Create LogWriter instances for stdout and stderr using the shared log sender
418-
let stdout_writer =
419-
LogWriter::with_default_writer(local_rank, OutputTarget::Stdout, log_sender.clone())?;
420-
let stderr_writer =
421-
LogWriter::with_default_writer(local_rank, OutputTarget::Stderr, log_sender)?;
439+
let stdout_writer = LogWriter::with_default_writer(
440+
local_rank,
441+
OutputTarget::Stdout,
442+
LocalLogSender::new(log_channel.clone(), pid)?,
443+
)?;
444+
let stderr_writer = LogWriter::with_default_writer(
445+
local_rank,
446+
OutputTarget::Stderr,
447+
LocalLogSender::new(log_channel, pid)?,
448+
)?;
422449

423450
Ok((Box::new(stdout_writer), Box::new(stderr_writer)))
424451
}
@@ -495,7 +522,34 @@ impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static>
495522

496523
fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Result<(), io::Error>> {
497524
let this = self.get_mut();
498-
Pin::new(&mut this.std_writer).poll_flush(cx)
525+
526+
// First, flush the standard writer
527+
match Pin::new(&mut this.std_writer).poll_flush(cx) {
528+
Poll::Ready(Ok(())) => {
529+
// Now send a Flush message to the other side of the channel.
530+
let mut flush_future = this.log_sender.flush();
531+
match flush_future.as_mut().poll(cx) {
532+
Poll::Ready(Ok(())) => {
533+
// Successfully sent the flush message
534+
Poll::Ready(Ok(()))
535+
}
536+
Poll::Ready(Err(e)) => {
537+
// Error sending the flush message
538+
tracing::error!("error sending flush message: {}", e);
539+
Poll::Ready(Err(io::Error::other(format!(
540+
"error sending flush message: {}",
541+
e
542+
))))
543+
}
544+
Poll::Pending => {
545+
// The future is not ready yet, so we return Pending
546+
// The waker is already registered by polling the future
547+
Poll::Pending
548+
}
549+
}
550+
}
551+
other => other, // Propagate any errors or Pending state from the std_writer flush
552+
}
499553
}
500554

501555
fn poll_shutdown(
@@ -964,14 +1018,19 @@ mod tests {
9641018
// Mock implementation of LogSender for testing
9651019
struct MockLogSender {
9661020
log_sender: mpsc::UnboundedSender<(OutputTarget, String)>, // (output_target, content)
1021+
flush_called: Arc<Mutex<bool>>, // Track if flush was called
9671022
}
9681023

9691024
impl MockLogSender {
9701025
fn new(log_sender: mpsc::UnboundedSender<(OutputTarget, String)>) -> Self {
971-
Self { log_sender }
1026+
Self {
1027+
log_sender,
1028+
flush_called: Arc::new(Mutex::new(false)),
1029+
}
9721030
}
9731031
}
9741032

1033+
#[async_trait]
9751034
impl LogSender for MockLogSender {
9761035
fn send(&mut self, output_target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
9771036
// For testing purposes, convert to string if it's valid UTF-8
@@ -984,6 +1043,16 @@ mod tests {
9841043
.send((output_target, line))
9851044
.map_err(|e| anyhow::anyhow!("Failed to send log in test: {}", e))
9861045
}
1046+
1047+
async fn flush(&mut self) -> anyhow::Result<()> {
1048+
// Mark that flush was called
1049+
let mut flush_called = self.flush_called.lock().unwrap();
1050+
*flush_called = true;
1051+
1052+
// For testing purposes, just return Ok
1053+
// In a real implementation, this would wait for all messages to be delivered
1054+
Ok(())
1055+
}
9871056
}
9881057

9891058
#[tokio::test]
@@ -1098,6 +1167,32 @@ mod tests {
10981167
// The rest of the content will be replacement characters, but we don't care about the exact representation
10991168
}
11001169

1170+
#[tokio::test]
1171+
async fn test_log_writer_poll_flush() {
1172+
// Create a channel to receive logs
1173+
let (log_sender, _log_receiver) = mpsc::unbounded_channel();
1174+
1175+
// Create a mock log sender that tracks flush calls
1176+
let mock_log_sender = MockLogSender::new(log_sender);
1177+
let log_sender_flush_tracker = mock_log_sender.flush_called.clone();
1178+
1179+
// Create mock writers for stdout and stderr
1180+
let (stdout_mock_writer, _) = MockWriter::new();
1181+
let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);
1182+
1183+
// Create a log writer with the mocks
1184+
let mut writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, mock_log_sender);
1185+
1186+
// Call flush on the writer
1187+
writer.flush().await.unwrap();
1188+
1189+
// Verify that the log sender's flush was called
1190+
assert!(
1191+
*log_sender_flush_tracker.lock().unwrap(),
1192+
"LogSender's flush was not called"
1193+
);
1194+
}
1195+
11011196
#[test]
11021197
fn test_string_similarity() {
11031198
// Test exact match

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ async def _flush_logs() -> None:
4040
for _ in range(5):
4141
await am.print.call("has print streaming")
4242

43-
# Sleep a tiny so we allow the logs to stream back to the client
44-
await asyncio.sleep(1)
43+
await pm.stop()
4544

4645

4746
@main.command("flush-logs")

0 commit comments

Comments
 (0)