
Commit 6d5144a

James Sun authored and facebook-github-bot committed
sync flush of proc mesh (#823)
Summary: Provide a sync flush so that all the flushed logs on the remote procs are guaranteed to be streamed back and flushed on the client's stdout/stderr.

Differential Revision: D80051803
1 parent 334b974 commit 6d5144a
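
The guarantee rests on a countdown barrier added to LogClientActor (see hyperactor_mesh/src/logging.rs below): the client records how many procs must acknowledge the flush and a reply port, each proc's forwarder acknowledges once it has handled the synced flush, and when the count reaches zero the client flushes its aggregators and fires the reply. A minimal stand-alone sketch of that logic, with tokio's oneshot channel standing in for hyperactor's OncePortRef (FlushBarrier and its methods are made-up names, not the crate's API):

// Sketch of the countdown barrier LogClientActor gains in this commit.
// tokio::sync::oneshot stands in for OncePortRef; names are illustrative.
use tokio::sync::oneshot;

struct FlushBarrier {
    ongoing_reply: Option<oneshot::Sender<()>>,
    unflushed_procs: usize,
}

impl FlushBarrier {
    // Mirrors LogClientMessage::StartSyncFlush: remember how many acks to wait for.
    fn start(&mut self, expected_procs: usize, reply: oneshot::Sender<()>) -> Result<(), String> {
        if self.unflushed_procs > 0 || self.ongoing_reply.is_some() {
            return Err("a sync flush is already in progress".into());
        }
        self.ongoing_reply = Some(reply);
        self.unflushed_procs = expected_procs;
        Ok(())
    }

    // Mirrors LogMessage::Flush { synced: true } arriving from one proc's forwarder.
    fn ack(&mut self) -> Result<(), String> {
        if self.unflushed_procs == 0 || self.ongoing_reply.is_none() {
            return Err("found no ongoing flush request".into());
        }
        self.unflushed_procs -= 1;
        if self.unflushed_procs == 0 {
            // In the real actor this is flush_internal(): print the aggregators,
            // reset the flush timers, then unblock the waiting caller.
            let reply = self.ongoing_reply.take().unwrap();
            let _ = reply.send(());
        }
        Ok(())
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel();
    let mut barrier = FlushBarrier { ongoing_reply: None, unflushed_procs: 0 };
    barrier.start(2, tx).unwrap();
    barrier.ack().unwrap();
    barrier.ack().unwrap(); // the second ack flushes and fires the reply
    rx.await.unwrap(); // the caller's flush() resolves here
}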

File tree: 5 files changed, +219 −44 lines

hyperactor_mesh/src/logging.rs

Lines changed: 153 additions & 36 deletions
@@ -22,12 +22,15 @@ use chrono::DateTime;
 use chrono::Local;
 use hyperactor::Actor;
 use hyperactor::ActorRef;
+use hyperactor::Bind;
 use hyperactor::Context;
 use hyperactor::HandleClient;
 use hyperactor::Handler;
 use hyperactor::Instance;
 use hyperactor::Named;
+use hyperactor::OncePortRef;
 use hyperactor::RefClient;
+use hyperactor::Unbind;
 use hyperactor::channel;
 use hyperactor::channel::ChannelAddr;
 use hyperactor::channel::ChannelRx;
@@ -39,9 +42,6 @@ use hyperactor::channel::TxStatus;
 use hyperactor::clock::Clock;
 use hyperactor::clock::RealClock;
 use hyperactor::data::Serialized;
-use hyperactor::message::Bind;
-use hyperactor::message::Bindings;
-use hyperactor::message::Unbind;
 use hyperactor_telemetry::env;
 use hyperactor_telemetry::log_file_path;
 use serde::Deserialize;
@@ -235,6 +235,24 @@ impl fmt::Display for Aggregator {
     }
 }

+/// Messages that can be sent to the LogFlushActor remotely.
+#[derive(
+    Debug,
+    Clone,
+    Serialize,
+    Deserialize,
+    Named,
+    Handler,
+    HandleClient,
+    RefClient,
+    Bind,
+    Unbind
+)]
+pub enum LogFlushMessage {
+    /// Flush the log
+    ForceSyncFlush {},
+}
+
 /// Messages that can be sent to the LogClientActor remotely.
 #[derive(
     Debug,
@@ -260,7 +278,10 @@ pub enum LogMessage {
     },

     /// Flush the log
-    Flush {},
+    Flush {
+        /// If true, force a flush sync barrier across all procs
+        synced: bool,
+    },
 }

 /// Messages that can be sent to the LogClient locally.
@@ -279,6 +300,14 @@ pub enum LogClientMessage {
         /// The time window in seconds to aggregate logs. If None, aggregation is disabled.
         aggregate_window_sec: Option<u64>,
     },
+
+    /// Synchronously flush all the logs from all the procs. This is for the client to call.
+    StartSyncFlush {
+        /// Expect this many procs to ack the flush message.
+        expected_procs: usize,
+        /// Return once we have received the acks from all the procs.
+        reply: OncePortRef<()>,
+    },
 }

 /// Trait for sending logs
@@ -352,7 +381,7 @@ impl LogSender for LocalLogSender {
         // send will make sure message is delivered
         if TxStatus::Active == *self.status.borrow() {
             // Do not use tx.send, it will block the allocator as the child process state is unknown.
-            self.tx.post(LogMessage::Flush {});
+            self.tx.post(LogMessage::Flush { synced: false });
         } else {
             tracing::debug!(
                 "log sender {} is not active, skip sending flush message",
@@ -547,7 +576,9 @@ impl<T: LogSender + Unpin + 'static, S: io::AsyncWrite + Send + Unpin + 'static>
     Named,
     Handler,
     HandleClient,
-    RefClient
+    RefClient,
+    Bind,
+    Unbind
 )]
 pub enum LogForwardMessage {
     /// Receive the log from the parent process and forward it to the client.
@@ -557,18 +588,6 @@ pub enum LogForwardMessage {
     SetMode { stream_to_client: bool },
 }

-impl Bind for LogForwardMessage {
-    fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> {
-        Ok(())
-    }
-}
-
-impl Unbind for LogForwardMessage {
-    fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
-        Ok(())
-    }
-}
-
 /// A log forwarder that receives the log from its parent process and forwards it back to the client
 #[derive(Debug)]
 #[hyperactor::export(
@@ -636,17 +655,28 @@ impl Actor for LogForwardActor {
 #[hyperactor::forward(LogForwardMessage)]
 impl LogForwardMessageHandler for LogForwardActor {
     async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
-        if let Ok(LogMessage::Log {
-            hostname,
-            pid,
-            output_target,
-            payload,
-        }) = self.rx.recv().await
-        {
-            if self.stream_to_client {
-                self.logging_client_ref
-                    .log(ctx, hostname, pid, output_target, payload)
-                    .await?;
+        match self.rx.recv().await {
+            Ok(LogMessage::Flush { synced }) => {
+                if synced {
+                    self.logging_client_ref.flush(ctx, true).await?;
+                } else {
+                    // no need to do anything. The previous messages have already been sent to the client.
+                }
+            }
+            Ok(LogMessage::Log {
+                hostname,
+                pid,
+                output_target,
+                payload,
+            }) => {
+                if self.stream_to_client {
+                    self.logging_client_ref
+                        .log(ctx, hostname, pid, output_target, payload)
+                        .await?;
+                }
+            }
+            Err(e) => {
+                return Err(e.into());
             }
         }

@@ -685,6 +715,54 @@ fn deserialize_message_lines(
     anyhow::bail!("Failed to deserialize message as either String or Vec<u8>")
 }

+/// An actor that sends a flush message to the log forwarder actor.
+/// The reason we need an extra actor instead of reusing the log forwarder actor
+/// is that the log forwarder can be blocked on the rx.recv() that listens for new log lines.
+/// Thus, we need to create a new channel as a tx to send the flush message to the log forwarder
+/// so we do not get into a deadlock.
+#[derive(Debug)]
+#[hyperactor::export(
+    spawn = true,
+    handlers = [LogFlushMessage {cast = true}],
+)]
+pub struct LogFlushActor {
+    tx: ChannelTx<LogMessage>,
+}
+
+#[async_trait]
+impl Actor for LogFlushActor {
+    type Params = ();
+
+    async fn new(_: ()) -> Result<Self, anyhow::Error> {
+        let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) {
+            Ok(channel) => channel.parse()?,
+            Err(err) => {
+                tracing::debug!(
+                    "log forwarder actor failed to read env var {}: {}",
+                    BOOTSTRAP_LOG_CHANNEL,
+                    err
+                );
+                // TODO: this should error out; it can only happen with local proc; we need to fix it.
+                ChannelAddr::any(ChannelTransport::Unix)
+            }
+        };
+        let tx = channel::dial::<LogMessage>(log_channel)?;
+
+        Ok(Self { tx })
+    }
+}
+
+#[async_trait]
+#[hyperactor::forward(LogFlushMessage)]
+impl LogFlushMessageHandler for LogFlushActor {
+    async fn force_sync_flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
+        self.tx
+            .send(LogMessage::Flush { synced: true })
+            .await
+            .map_err(anyhow::Error::from)
+    }
+}
+
 /// A client to receive logs from remote processes
 #[derive(Debug)]
 #[hyperactor::export(
@@ -696,6 +774,8 @@ pub struct LogClientActor {
     aggregators: HashMap<OutputTarget, Aggregator>,
     last_flush_time: SystemTime,
     next_flush_deadline: Option<SystemTime>,
+    ongoing_flush_port: Option<OncePortRef<()>>,
+    unflushed_procs: usize,
 }

 impl LogClientActor {
@@ -725,6 +805,12 @@ impl LogClientActor {
             OutputTarget::Stderr => eprintln!("{}", message),
         }
     }
+
+    fn flush_internal(&mut self) {
+        self.print_aggregators();
+        self.last_flush_time = RealClock.system_time_now();
+        self.next_flush_deadline = None;
+    }
 }

 #[async_trait]
@@ -743,6 +829,8 @@ impl Actor for LogClientActor {
             aggregators,
             last_flush_time: RealClock.system_time_now(),
             next_flush_deadline: None,
+            ongoing_flush_port: None,
+            unflushed_procs: 0,
         })
     }
 }
@@ -794,20 +882,23 @@ impl LogMessageHandler for LogClientActor {
         let new_deadline = self.last_flush_time + Duration::from_secs(window);
         let now = RealClock.system_time_now();
         if new_deadline <= now {
-            self.flush(cx).await?;
+            self.flush_internal();
         } else {
             let delay = new_deadline.duration_since(now)?;
             match self.next_flush_deadline {
                 None => {
                     self.next_flush_deadline = Some(new_deadline);
-                    cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                    cx.self_message_with_delay(LogMessage::Flush { synced: false }, delay)?;
                 }
                 Some(deadline) => {
                     // Some early log lines have already triggered the flush.
                     if new_deadline < deadline {
                         // This can happen if the user has adjusted the aggregation window.
                         self.next_flush_deadline = Some(new_deadline);
-                        cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                        cx.self_message_with_delay(
+                            LogMessage::Flush { synced: false },
+                            delay,
+                        )?;
                     }
                 }
             }
@@ -818,10 +909,21 @@
         Ok(())
     }

-    async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
-        self.print_aggregators();
-        self.last_flush_time = RealClock.system_time_now();
-        self.next_flush_deadline = None;
+    async fn flush(&mut self, cx: &Context<Self>, synced: bool) -> Result<(), anyhow::Error> {
+        if synced {
+            if self.unflushed_procs == 0 || self.ongoing_flush_port.is_none() {
+                anyhow::bail!("found no ongoing flush request");
+            }
+            self.unflushed_procs -= 1;
+            if self.unflushed_procs == 0 {
+                self.flush_internal();
+                let reply = self.ongoing_flush_port.take().unwrap();
+                self.ongoing_flush_port = None;
+                reply.send(cx, ()).map_err(anyhow::Error::from)?;
+            }
+        } else {
+            self.flush_internal();
+        }

         Ok(())
     }
@@ -842,6 +944,21 @@ impl LogClientMessageHandler for LogClientActor {
         self.aggregate_window_sec = aggregate_window_sec;
         Ok(())
     }
+
+    async fn start_sync_flush(
+        &mut self,
+        _cx: &Context<Self>,
+        expected_procs_flushed: usize,
+        reply: OncePortRef<()>,
+    ) -> Result<(), anyhow::Error> {
+        if self.unflushed_procs > 0 || self.ongoing_flush_port.is_some() {
+            anyhow::bail!("forcing a flush while the ongoing flush has not finished yet");
+        }
+
+        self.ongoing_flush_port = Some(reply.clone());
+        self.unflushed_procs = expected_procs_flushed;
+        Ok(())
+    }
 }

 #[cfg(test)]
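
The LogFlushActor doc comment above is the key design note: the forwarder spends its time blocked in rx.recv() on the proc's log channel, so a synced flush is delivered as just one more message on that channel (sent over a freshly dialed connection) rather than as a call into the blocked actor. A stand-alone illustration of that shape, with tokio's mpsc channel standing in for the hyperactor channel and an explicit ack in place of the forwarder's call to logging_client_ref.flush(ctx, true); all names here are illustrative:

// Why a second sender avoids the deadlock: the "forwarder" task only ever
// blocks on the log channel, so the synced flush arrives as data on that
// channel instead of requiring a call into the busy task.
use tokio::sync::{mpsc, oneshot};

enum LogMessage {
    Log(String),
    Flush { synced: bool, ack: oneshot::Sender<()> },
}

#[tokio::main]
async fn main() {
    // The proc's log channel: producers post log lines onto it.
    let (log_tx, mut rx) = mpsc::unbounded_channel::<LogMessage>();
    // The "flush actor" path: a separate sender into the same channel.
    let flush_tx = log_tx.clone();

    // Forwarder: blocked on recv(), handles whatever arrives, in arrival order.
    let forwarder = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            match msg {
                LogMessage::Log(line) => println!("forwarded: {line}"),
                LogMessage::Flush { synced, ack } => {
                    if synced {
                        // In the real code this is logging_client_ref.flush(ctx, true).
                        let _ = ack.send(());
                    }
                }
            }
        }
    });

    log_tx.send(LogMessage::Log("hello from a proc".into())).unwrap();

    // Inject the synced flush through the separate sender and wait for the ack.
    let (ack_tx, ack_rx) = oneshot::channel();
    flush_tx
        .send(LogMessage::Flush { synced: true, ack: ack_tx })
        .unwrap();
    ack_rx.await.unwrap();

    drop(log_tx);
    drop(flush_tx);
    forwarder.await.unwrap();
}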

monarch_extension/src/logging.rs

Lines changed: 37 additions & 0 deletions
@@ -13,6 +13,8 @@ use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::logging::LogClientActor;
 use hyperactor_mesh::logging::LogClientMessage;
+use hyperactor_mesh::logging::LogFlushActor;
+use hyperactor_mesh::logging::LogFlushMessage;
 use hyperactor_mesh::logging::LogForwardActor;
 use hyperactor_mesh::logging::LogForwardMessage;
 use hyperactor_mesh::selection::Selection;
@@ -33,6 +35,9 @@ use pyo3::types::PyModule;
 pub struct LoggingMeshClient {
     // handles remote process log forwarding; no python runtime
     forwarder_mesh: SharedCell<RootActorMesh<'static, LogForwardActor>>,
+    // Because the forwarder mesh keeps listening for incoming logs,
+    // the flush mesh is a separate way to trigger a sync flush without blocking on it.
+    flush_mesh: SharedCell<RootActorMesh<'static, LogFlushActor>>,
     // handles python logger; has python runtime
     logger_mesh: SharedCell<RootActorMesh<'static, LoggerRuntimeActor>>,
     client_actor: ActorHandle<LogClientActor>,
@@ -47,9 +52,11 @@ impl LoggingMeshClient {
         let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?;
         let client_actor_ref = client_actor.bind();
         let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
+        let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?;
         let logger_mesh = proc_mesh.spawn("logger", &()).await?;
         Ok(Self {
             forwarder_mesh,
+            flush_mesh,
             logger_mesh,
             client_actor,
         })
@@ -97,6 +104,36 @@ impl LoggingMeshClient {

         Ok(())
     }
+
+    // A sync flush mechanism for the client to make sure all the stdout/stderr are streamed back and flushed.
+    fn flush(&self, proc_mesh: &PyProcMesh) -> PyResult<PyPythonTask> {
+        let inner_mesh = proc_mesh.try_inner()?;
+        let (tx, rx) = inner_mesh.client().open_once_port::<()>();
+        // First, initialize a sync flush.
+        self.client_actor
+            .send(LogClientMessage::StartSyncFlush {
+                expected_procs: inner_mesh.shape().slice().len(),
+                reply: tx.bind(),
+            })
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+
+        // Then ask all the flushers to tell the log forwarders to sync flush.
+        let flush_inner_mesh = self.flush_mesh.borrow().map_err(anyhow::Error::msg)?;
+        flush_inner_mesh
+            .cast(
+                inner_mesh.client(),
+                Selection::True,
+                LogFlushMessage::ForceSyncFlush {},
+            )
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+
+        // Finally, each forwarder sends the sync point back to the client, which flushes and replies.
+        PyPythonTask::new(async move {
+            rx.recv()
+                .await
+                .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
+        })
+    }
 }

 impl Drop for LoggingMeshClient {
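
Putting the three steps together, the new flush() above reduces to a short sequence. A condensed restatement with the PyO3 error plumbing removed; the handles, message types, and method calls are the ones that appear in the diffs of this commit, but this is an illustrative sketch and not guaranteed to compile verbatim outside this module:

// Condensed sketch of LoggingMeshClient::flush, PyO3 plumbing removed.
// Types and calls are taken from the diffs above (illustrative only).
async fn sync_flush(
    client_actor: &ActorHandle<LogClientActor>,
    flush_mesh: &RootActorMesh<'static, LogFlushActor>,
    proc_mesh: &ProcMesh,
) -> anyhow::Result<()> {
    // 1. Tell the log client how many acks to expect and where to reply.
    let (tx, rx) = proc_mesh.client().open_once_port::<()>();
    client_actor.send(LogClientMessage::StartSyncFlush {
        expected_procs: proc_mesh.shape().slice().len(),
        reply: tx.bind(),
    })?;

    // 2. Cast ForceSyncFlush to every proc's flusher; each one injects
    //    LogMessage::Flush { synced: true } into its proc's log channel.
    flush_mesh.cast(
        proc_mesh.client(),
        Selection::True,
        LogFlushMessage::ForceSyncFlush {},
    )?;

    // 3. Every forwarder acks via flush(ctx, true); once the client has seen
    //    all acks it flushes its aggregators and replies on the once port.
    rx.recv().await?;
    Ok(())
}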

python/monarch/_rust_bindings/monarch_extension/logging.pyi

Lines changed: 1 addition & 0 deletions
@@ -21,3 +21,4 @@ class LoggingMeshClient:
     def set_mode(
         self, stream_to_client: bool, aggregate_window_sec: int | None, level: int
     ) -> None: ...
+    def flush(self, proc_mesh: ProcMesh) -> PythonTask[None]: ...
