Skip to content

Commit 9bf59c9

Browse files
James Sun and facebook-github-bot
authored and committed
sync flush of proc mesh (meta-pytorch#823)
Summary: Provide a sync flush so that all flushed logs on the remote procs are guaranteed to be streamed back and flushed to the client's stdout/stderr. Differential Revision: D80051803
1 parent b5cda4d commit 9bf59c9

File tree

5 files changed

+221
-44
lines changed

5 files changed

+221
-44
lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 155 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@ use chrono::DateTime;
2222
use chrono::Local;
2323
use hyperactor::Actor;
2424
use hyperactor::ActorRef;
25+
use hyperactor::Bind;
2526
use hyperactor::Context;
2627
use hyperactor::HandleClient;
2728
use hyperactor::Handler;
2829
use hyperactor::Instance;
2930
use hyperactor::Named;
31+
use hyperactor::OncePortRef;
3032
use hyperactor::RefClient;
33+
use hyperactor::Unbind;
3134
use hyperactor::channel;
3235
use hyperactor::channel::ChannelAddr;
3336
use hyperactor::channel::ChannelRx;
@@ -39,9 +42,6 @@ use hyperactor::channel::TxStatus;
3942
use hyperactor::clock::Clock;
4043
use hyperactor::clock::RealClock;
4144
use hyperactor::data::Serialized;
42-
use hyperactor::message::Bind;
43-
use hyperactor::message::Bindings;
44-
use hyperactor::message::Unbind;
4545
use hyperactor_telemetry::env;
4646
use hyperactor_telemetry::log_file_path;
4747
use serde::Deserialize;
@@ -235,6 +235,24 @@ impl fmt::Display for Aggregator {
235235
}
236236
}
237237

238+
/// Messages that can be sent to the LogFlushActor remotely.
239+
#[derive(
240+
Debug,
241+
Clone,
242+
Serialize,
243+
Deserialize,
244+
Named,
245+
Handler,
246+
HandleClient,
247+
RefClient,
248+
Bind,
249+
Unbind
250+
)]
251+
pub enum LogFlushMessage {
252+
/// Request a synced flush of the logs.
253+
ForceSyncFlush {},
254+
}
255+
238256
/// Messages that can be sent to the LogClientActor remotely.
239257
#[derive(
240258
Debug,
@@ -260,7 +278,10 @@ pub enum LogMessage {
260278
},
261279

262280
/// Flush the log
263-
Flush {},
281+
Flush {
282+
/// If true, force a flush sync barrier across all procs
283+
synced: bool,
284+
},
264285
}
265286

266287
/// Messages that can be sent to the LogClient locally.
@@ -279,6 +300,14 @@ pub enum LogClientMessage {
279300
/// The time window in seconds to aggregate logs. If None, aggregation is disabled.
280301
aggregate_window_sec: Option<u64>,
281302
},
303+
304+
/// Synchronously flush all the logs from all the procs. This is for the client to call.
305+
StartSyncFlush {
306+
/// Expect this many procs to ack the flush message.
307+
expected_procs: usize,
308+
/// Return once we have received the acks from all the procs
309+
reply: OncePortRef<()>,
310+
},
282311
}
283312

284313
/// Trait for sending logs
@@ -351,7 +380,9 @@ impl LogSender for LocalLogSender {
351380
async fn flush(&mut self) -> anyhow::Result<()> {
352381
// send will make sure message is delivered
353382
if TxStatus::Active == *self.status.borrow() {
354-
match self.tx.send(LogMessage::Flush {}).await {
383+
// this is just to make sure the log line is sent to the other side of the channel.
384+
// it is up to the forwarder to decide when to flush the log.
385+
match self.tx.send(LogMessage::Flush { synced: false }).await {
355386
Ok(()) => Ok(()),
356387
Err(e) => {
357388
tracing::error!("log sender {} error sending flush message: {}", self.pid, e);
@@ -570,7 +601,9 @@ impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static>
570601
Named,
571602
Handler,
572603
HandleClient,
573-
RefClient
604+
RefClient,
605+
Bind,
606+
Unbind
574607
)]
575608
pub enum LogForwardMessage {
576609
/// Receive the log from the parent process and forward it to the client.
@@ -580,18 +613,6 @@ pub enum LogForwardMessage {
580613
SetMode { stream_to_client: bool },
581614
}
582615

583-
impl Bind for LogForwardMessage {
584-
fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> {
585-
Ok(())
586-
}
587-
}
588-
589-
impl Unbind for LogForwardMessage {
590-
fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
591-
Ok(())
592-
}
593-
}
594-
595616
/// A log forwarder that receives the log from its parent process and forwards it back to the client
596617
#[derive(Debug)]
597618
#[hyperactor::export(
@@ -659,17 +680,28 @@ impl Actor for LogForwardActor {
659680
#[hyperactor::forward(LogForwardMessage)]
660681
impl LogForwardMessageHandler for LogForwardActor {
661682
async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
662-
if let Ok(LogMessage::Log {
663-
hostname,
664-
pid,
665-
output_target,
666-
payload,
667-
}) = self.rx.recv().await
668-
{
669-
if self.stream_to_client {
670-
self.logging_client_ref
671-
.log(ctx, hostname, pid, output_target, payload)
672-
.await?;
683+
match self.rx.recv().await {
684+
Ok(LogMessage::Flush { synced }) => {
685+
if synced {
686+
self.logging_client_ref.flush(ctx, true).await?;
687+
} else {
688+
// no need to do anything. The previous messages have already been sent to the client.
689+
}
690+
}
691+
Ok(LogMessage::Log {
692+
hostname,
693+
pid,
694+
output_target,
695+
payload,
696+
}) => {
697+
if self.stream_to_client {
698+
self.logging_client_ref
699+
.log(ctx, hostname, pid, output_target, payload)
700+
.await?;
701+
}
702+
}
703+
Err(e) => {
704+
return Err(e.into());
673705
}
674706
}
675707

@@ -708,6 +740,54 @@ fn deserialize_message_lines(
708740
anyhow::bail!("Failed to deserialize message as either String or Vec<u8>")
709741
}
710742

743+
/// An actor that sends flush messages to the log forwarder actor.
744+
/// The reason we need an extra actor instead of reusing the log forwarder actor
745+
/// is because the log forwarder can be blocked on the rx.recv() that listens on the new log lines.
746+
/// Thus, we need to create a new channel as a tx to send the flush message to the log forwarder
747+
/// So we do not get into a deadlock.
748+
#[derive(Debug)]
749+
#[hyperactor::export(
750+
spawn = true,
751+
handlers = [LogFlushMessage {cast = true}],
752+
)]
753+
pub struct LogFlushActor {
754+
tx: ChannelTx<LogMessage>,
755+
}
756+
757+
#[async_trait]
758+
impl Actor for LogFlushActor {
759+
type Params = ();
760+
761+
async fn new(_: ()) -> Result<Self, anyhow::Error> {
762+
let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) {
763+
Ok(channel) => channel.parse()?,
764+
Err(err) => {
765+
tracing::debug!(
766+
"log forwarder actor failed to read env var {}: {}",
767+
BOOTSTRAP_LOG_CHANNEL,
768+
err
769+
);
770+
// TODO: this should error out; it can only happen with local proc; we need to fix it.
771+
ChannelAddr::any(ChannelTransport::Unix)
772+
}
773+
};
774+
let tx = channel::dial::<LogMessage>(log_channel)?;
775+
776+
Ok(Self { tx })
777+
}
778+
}
779+
780+
#[async_trait]
781+
#[hyperactor::forward(LogFlushMessage)]
782+
impl LogFlushMessageHandler for LogFlushActor {
783+
async fn force_sync_flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
784+
self.tx
785+
.send(LogMessage::Flush { synced: true })
786+
.await
787+
.map_err(anyhow::Error::from)
788+
}
789+
}
790+
711791
/// A client to receive logs from remote processes
712792
#[derive(Debug)]
713793
#[hyperactor::export(
@@ -719,6 +799,8 @@ pub struct LogClientActor {
719799
aggregators: HashMap<OutputTarget, Aggregator>,
720800
last_flush_time: SystemTime,
721801
next_flush_deadline: Option<SystemTime>,
802+
ongoing_flush_port: Option<OncePortRef<()>>,
803+
unflushed_procs: usize,
722804
}
723805

724806
impl LogClientActor {
@@ -748,6 +830,12 @@ impl LogClientActor {
748830
OutputTarget::Stderr => eprintln!("{}", message),
749831
}
750832
}
833+
834+
fn flush_internal(&mut self) {
835+
self.print_aggregators();
836+
self.last_flush_time = RealClock.system_time_now();
837+
self.next_flush_deadline = None;
838+
}
751839
}
752840

753841
#[async_trait]
@@ -766,6 +854,8 @@ impl Actor for LogClientActor {
766854
aggregators,
767855
last_flush_time: RealClock.system_time_now(),
768856
next_flush_deadline: None,
857+
ongoing_flush_port: None,
858+
unflushed_procs: 0,
769859
})
770860
}
771861
}
@@ -817,20 +907,23 @@ impl LogMessageHandler for LogClientActor {
817907
let new_deadline = self.last_flush_time + Duration::from_secs(window);
818908
let now = RealClock.system_time_now();
819909
if new_deadline <= now {
820-
self.flush(cx).await?;
910+
self.flush_internal();
821911
} else {
822912
let delay = new_deadline.duration_since(now)?;
823913
match self.next_flush_deadline {
824914
None => {
825915
self.next_flush_deadline = Some(new_deadline);
826-
cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
916+
cx.self_message_with_delay(LogMessage::Flush { synced: false }, delay)?;
827917
}
828918
Some(deadline) => {
829919
// Some early log lines have already triggered the flush.
830920
if new_deadline < deadline {
831921
// This can happen if the user has adjusted the aggregation window.
832922
self.next_flush_deadline = Some(new_deadline);
833-
cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
923+
cx.self_message_with_delay(
924+
LogMessage::Flush { synced: false },
925+
delay,
926+
)?;
834927
}
835928
}
836929
}
@@ -841,10 +934,21 @@ impl LogMessageHandler for LogClientActor {
841934
Ok(())
842935
}
843936

844-
async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
845-
self.print_aggregators();
846-
self.last_flush_time = RealClock.system_time_now();
847-
self.next_flush_deadline = None;
937+
async fn flush(&mut self, cx: &Context<Self>, synced: bool) -> Result<(), anyhow::Error> {
938+
if synced {
939+
if self.unflushed_procs == 0 || self.ongoing_flush_port.is_none() {
940+
anyhow::bail!("found no ongoing flush request");
941+
}
942+
self.unflushed_procs -= 1;
943+
if self.unflushed_procs == 0 {
944+
self.flush_internal();
945+
let reply = self.ongoing_flush_port.take().unwrap();
946+
self.ongoing_flush_port = None;
947+
reply.send(cx, ()).map_err(anyhow::Error::from)?;
948+
}
949+
} else {
950+
self.flush_internal();
951+
}
848952

849953
Ok(())
850954
}
@@ -865,6 +969,21 @@ impl LogClientMessageHandler for LogClientActor {
865969
self.aggregate_window_sec = aggregate_window_sec;
866970
Ok(())
867971
}
972+
973+
async fn start_sync_flush(
974+
&mut self,
975+
_cx: &Context<Self>,
976+
expected_procs_flushed: usize,
977+
reply: OncePortRef<()>,
978+
) -> Result<(), anyhow::Error> {
979+
if self.unflushed_procs > 0 || self.ongoing_flush_port.is_some() {
980+
anyhow::bail!("forcing a flush while the ongoing flush has not finished yet");
981+
}
982+
983+
self.ongoing_flush_port = Some(reply.clone());
984+
self.unflushed_procs = expected_procs_flushed;
985+
Ok(())
986+
}
868987
}
869988

870989
#[cfg(test)]

monarch_extension/src/logging.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use hyperactor_mesh::RootActorMesh;
1313
use hyperactor_mesh::actor_mesh::ActorMesh;
1414
use hyperactor_mesh::logging::LogClientActor;
1515
use hyperactor_mesh::logging::LogClientMessage;
16+
use hyperactor_mesh::logging::LogFlushActor;
17+
use hyperactor_mesh::logging::LogFlushMessage;
1618
use hyperactor_mesh::logging::LogForwardActor;
1719
use hyperactor_mesh::logging::LogForwardMessage;
1820
use hyperactor_mesh::selection::Selection;
@@ -33,6 +35,9 @@ use pyo3::types::PyModule;
3335
pub struct LoggingMeshClient {
3436
// handles remote process log forwarding; no python runtime
3537
forwarder_mesh: SharedCell<RootActorMesh<'static, LogForwardActor>>,
38+
// because forwarder mesh keeps listening to the new coming logs,
39+
// the flush mesh is a way to unblock it from waiting on a log and do a sync flush.
40+
flush_mesh: SharedCell<RootActorMesh<'static, LogFlushActor>>,
3641
// handles python logger; has python runtime
3742
logger_mesh: SharedCell<RootActorMesh<'static, LoggerRuntimeActor>>,
3843
client_actor: ActorHandle<LogClientActor>,
@@ -47,9 +52,11 @@ impl LoggingMeshClient {
4752
let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?;
4853
let client_actor_ref = client_actor.bind();
4954
let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
55+
let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?;
5056
let logger_mesh = proc_mesh.spawn("logger", &()).await?;
5157
Ok(Self {
5258
forwarder_mesh,
59+
flush_mesh,
5360
logger_mesh,
5461
client_actor,
5562
})
@@ -97,6 +104,36 @@ impl LoggingMeshClient {
97104

98105
Ok(())
99106
}
107+
108+
// A sync flush mechanism for the client to make sure all the stdout/stderr are streamed back and flushed.
109+
fn flush(&self, proc_mesh: &PyProcMesh) -> PyResult<PyPythonTask> {
110+
let inner_mesh = proc_mesh.try_inner()?;
111+
let (tx, rx) = inner_mesh.client().open_once_port::<()>();
112+
// First initialize a sync flush.
113+
self.client_actor
114+
.send(LogClientMessage::StartSyncFlush {
115+
expected_procs: inner_mesh.shape().slice().len(),
116+
reply: tx.bind(),
117+
})
118+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
119+
120+
// Then ask all the flushers to ask the log forwarders to sync flush
121+
let flush_inner_mesh = self.flush_mesh.borrow().map_err(anyhow::Error::msg)?;
122+
flush_inner_mesh
123+
.cast(
124+
inner_mesh.client(),
125+
Selection::True,
126+
LogFlushMessage::ForceSyncFlush {},
127+
)
128+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
129+
130+
// Finally the forwarder will send sync point back to the client, flush, and return.
131+
PyPythonTask::new(async move {
132+
rx.recv()
133+
.await
134+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
135+
})
136+
}
100137
}
101138

102139
impl Drop for LoggingMeshClient {

python/monarch/_rust_bindings/monarch_extension/logging.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ class LoggingMeshClient:
2121
def set_mode(
2222
self, stream_to_client: bool, aggregate_window_sec: int | None, level: int
2323
) -> None: ...
24+
def flush(self, proc_mesh: ProcMesh) -> PythonTask[None]: ...

0 commit comments

Comments
 (0)