Skip to content

Commit b4af8e6

Browse files
highker authored and facebook-github-bot committed
log send with poll_flush
Differential Revision: D79978241
1 parent 45e8769 commit b4af8e6

File tree

2 files changed

+109
-15
lines changed

2 files changed

+109
-15
lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 108 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@ use std::fmt;
1111
use std::path::Path;
1212
use std::path::PathBuf;
1313
use std::pin::Pin;
14-
use std::sync::Arc;
1514
use std::task::Context as TaskContext;
1615
use std::task::Poll;
1716
use std::time::Duration;
@@ -287,6 +286,10 @@ pub enum LogClientMessage {
287286
pub trait LogSender: Send + Sync {
288287
/// Send a log payload in bytes
289288
fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()>;
289+
290+
/// Flush the log channel, ensuring all messages are delivered
291+
/// Returns when the flush message has been acknowledged
292+
async fn flush(&mut self) -> anyhow::Result<()>;
290293
}
291294

292295
/// Represents the target output stream (stdout or stderr)
@@ -299,11 +302,10 @@ pub enum OutputTarget {
299302
}
300303

301304
/// Write the log to a local unix channel so some actors can listen to it and stream the log back.
302-
#[derive(Clone)]
303305
pub struct LocalLogSender {
304306
hostname: String,
305307
pid: u32,
306-
tx: Arc<ChannelTx<LogMessage>>,
308+
tx: ChannelTx<LogMessage>,
307309
status: Receiver<TxStatus>,
308310
}
309311

@@ -319,30 +321,51 @@ impl LocalLogSender {
319321
Ok(Self {
320322
hostname,
321323
pid,
322-
tx: Arc::new(tx),
324+
tx,
323325
status,
324326
})
325327
}
326328
}
327329

330+
#[async_trait]
328331
impl LogSender for LocalLogSender {
329332
fn send(&mut self, target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
330333
if TxStatus::Active == *self.status.borrow() {
334+
// `post` does not guarantee that the message is delivered
331335
self.tx.post(LogMessage::Log {
332336
hostname: self.hostname.clone(),
333337
pid: self.pid,
334338
output_target: target,
335339
payload: Serialized::serialize_anon(&payload)?,
336340
});
337341
} else {
338-
tracing::trace!(
342+
tracing::debug!(
339343
"log sender {} is not active, skip sending log",
340344
self.tx.addr()
341345
)
342346
}
343347

344348
Ok(())
345349
}
350+
351+
async fn flush(&mut self) -> anyhow::Result<()> {
352+
// `send` makes sure the message is delivered
353+
if TxStatus::Active == *self.status.borrow() {
354+
match self.tx.send(LogMessage::Flush {}).await {
355+
Ok(()) => Ok(()),
356+
Err(e) => {
357+
tracing::error!("log sender {} error sending flush message: {}", self.pid, e);
358+
Err(anyhow::anyhow!("error sending flush message: {}", e))
359+
}
360+
}
361+
} else {
362+
tracing::debug!(
363+
"log sender {} is not active, skip sending flush message",
364+
self.tx.addr()
365+
);
366+
Ok(())
367+
}
368+
}
346369
}
347370

348371
/// A custom writer that tees to both stdout/stderr.
@@ -412,13 +435,17 @@ pub fn create_log_writers(
412435
),
413436
anyhow::Error,
414437
> {
415-
let log_sender = LocalLogSender::new(log_channel, pid)?;
416-
417438
// Create LogWriter instances for stdout and stderr using the shared log sender
418-
let stdout_writer =
419-
LogWriter::with_default_writer(local_rank, OutputTarget::Stdout, log_sender.clone())?;
420-
let stderr_writer =
421-
LogWriter::with_default_writer(local_rank, OutputTarget::Stderr, log_sender)?;
439+
let stdout_writer = LogWriter::with_default_writer(
440+
local_rank,
441+
OutputTarget::Stdout,
442+
LocalLogSender::new(log_channel.clone(), pid)?,
443+
)?;
444+
let stderr_writer = LogWriter::with_default_writer(
445+
local_rank,
446+
OutputTarget::Stderr,
447+
LocalLogSender::new(log_channel, pid)?,
448+
)?;
422449

423450
Ok((Box::new(stdout_writer), Box::new(stderr_writer)))
424451
}
@@ -495,7 +522,34 @@ impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static>
495522

496523
fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Result<(), io::Error>> {
497524
let this = self.get_mut();
498-
Pin::new(&mut this.std_writer).poll_flush(cx)
525+
526+
// First, flush the standard writer
527+
match Pin::new(&mut this.std_writer).poll_flush(cx) {
528+
Poll::Ready(Ok(())) => {
529+
// Now send a Flush message to the other side of the channel.
530+
let mut flush_future = this.log_sender.flush();
531+
match flush_future.as_mut().poll(cx) {
532+
Poll::Ready(Ok(())) => {
533+
// Successfully sent the flush message
534+
Poll::Ready(Ok(()))
535+
}
536+
Poll::Ready(Err(e)) => {
537+
// Error sending the flush message
538+
tracing::error!("error sending flush message: {}", e);
539+
Poll::Ready(Err(io::Error::other(format!(
540+
"error sending flush message: {}",
541+
e
542+
))))
543+
}
544+
Poll::Pending => {
545+
// The future is not ready yet, so we return Pending
546+
// The waker is already registered by polling the future
547+
Poll::Pending
548+
}
549+
}
550+
}
551+
other => other, // Propagate any errors or Pending state from the std_writer flush
552+
}
499553
}
500554

501555
fn poll_shutdown(
@@ -964,14 +1018,19 @@ mod tests {
9641018
// Mock implementation of LogSender for testing
9651019
struct MockLogSender {
9661020
log_sender: mpsc::UnboundedSender<(OutputTarget, String)>, // (output_target, content)
1021+
flush_called: Arc<Mutex<bool>>, // Track if flush was called
9671022
}
9681023

9691024
impl MockLogSender {
9701025
fn new(log_sender: mpsc::UnboundedSender<(OutputTarget, String)>) -> Self {
971-
Self { log_sender }
1026+
Self {
1027+
log_sender,
1028+
flush_called: Arc::new(Mutex::new(false)),
1029+
}
9721030
}
9731031
}
9741032

1033+
#[async_trait]
9751034
impl LogSender for MockLogSender {
9761035
fn send(&mut self, output_target: OutputTarget, payload: Vec<u8>) -> anyhow::Result<()> {
9771036
// For testing purposes, convert to string if it's valid UTF-8
@@ -984,6 +1043,16 @@ mod tests {
9841043
.send((output_target, line))
9851044
.map_err(|e| anyhow::anyhow!("Failed to send log in test: {}", e))
9861045
}
1046+
1047+
async fn flush(&mut self) -> anyhow::Result<()> {
1048+
// Mark that flush was called
1049+
let mut flush_called = self.flush_called.lock().unwrap();
1050+
*flush_called = true;
1051+
1052+
// For testing purposes, just return Ok
1053+
// In a real implementation, this would wait for all messages to be delivered
1054+
Ok(())
1055+
}
9871056
}
9881057

9891058
#[tokio::test]
@@ -1098,6 +1167,32 @@ mod tests {
10981167
// The rest of the content will be replacement characters, but we don't care about the exact representation
10991168
}
11001169

1170+
#[tokio::test]
1171+
async fn test_log_writer_poll_flush() {
1172+
// Create a channel to receive logs
1173+
let (log_sender, _log_receiver) = mpsc::unbounded_channel();
1174+
1175+
// Create a mock log sender that tracks flush calls
1176+
let mock_log_sender = MockLogSender::new(log_sender);
1177+
let log_sender_flush_tracker = mock_log_sender.flush_called.clone();
1178+
1179+
// Create mock writers for stdout and stderr
1180+
let (stdout_mock_writer, _) = MockWriter::new();
1181+
let stdout_writer: Box<dyn io::AsyncWrite + Send + Unpin> = Box::new(stdout_mock_writer);
1182+
1183+
// Create a log writer with the mocks
1184+
let mut writer = LogWriter::new(OutputTarget::Stdout, stdout_writer, mock_log_sender);
1185+
1186+
// Call flush on the writer
1187+
writer.flush().await.unwrap();
1188+
1189+
// Verify that the log sender's flush was called
1190+
assert!(
1191+
*log_sender_flush_tracker.lock().unwrap(),
1192+
"LogSender's flush was not called"
1193+
);
1194+
}
1195+
11011196
#[test]
11021197
fn test_string_similarity() {
11031198
// Test exact match

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,7 @@ async def _flush_logs() -> None:
4040
for _ in range(5):
4141
await am.print.call("has print streaming")
4242

43-
# Sleep a tiny so we allow the logs to stream back to the client
44-
await asyncio.sleep(1)
43+
await pm.stop()
4544

4645

4746
@main.command("flush-logs")

0 commit comments

Comments
 (0)