diff --git a/hyperactor_mesh/src/logging.rs b/hyperactor_mesh/src/logging.rs index 48b6e619e..282ca8593 100644 --- a/hyperactor_mesh/src/logging.rs +++ b/hyperactor_mesh/src/logging.rs @@ -22,12 +22,15 @@ use chrono::DateTime; use chrono::Local; use hyperactor::Actor; use hyperactor::ActorRef; +use hyperactor::Bind; use hyperactor::Context; use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; +use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::Unbind; use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelRx; @@ -39,9 +42,6 @@ use hyperactor::channel::TxStatus; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::data::Serialized; -use hyperactor::message::Bind; -use hyperactor::message::Bindings; -use hyperactor::message::Unbind; use hyperactor_telemetry::env; use hyperactor_telemetry::log_file_path; use serde::Deserialize; @@ -235,6 +235,24 @@ impl fmt::Display for Aggregator { } } +/// Messages that can be sent to the LogClientActor remotely. +#[derive( + Debug, + Clone, + Serialize, + Deserialize, + Named, + Handler, + HandleClient, + RefClient, + Bind, + Unbind +)] +pub enum LogFlushMessage { + /// Flush the log + ForceSyncFlush { version: u64 }, +} + /// Messages that can be sent to the LogClientActor remotely. #[derive( Debug, @@ -260,7 +278,11 @@ pub enum LogMessage { }, /// Flush the log - Flush {}, + Flush { + /// Indicate if the current flush is synced or non-synced. + /// If synced, a version number is available. Otherwise, none. + sync_version: Option, + }, } /// Messages that can be sent to the LogClient locally. @@ -279,6 +301,16 @@ pub enum LogClientMessage { /// The time window in seconds to aggregate logs. If None, aggregation is disabled. aggregate_window_sec: Option, }, + + /// Synchronously flush all the logs from all the procs. This is for client to call. + StartSyncFlush { + /// Expect these many procs to ack the flush message. + expected_procs: usize, + /// Return once we have received the acks from all the procs + reply: OncePortRef<()>, + /// Return to the caller the current flush version + version: OncePortRef, + }, } /// Trait for sending logs @@ -352,7 +384,7 @@ impl LogSender for LocalLogSender { // send will make sure message is delivered if TxStatus::Active == *self.status.borrow() { // Do not use tx.send, it will block the allocator as the child process state is unknown. - self.tx.post(LogMessage::Flush {}); + self.tx.post(LogMessage::Flush { sync_version: None }); } else { tracing::debug!( "log sender {} is not active, skip sending flush message", @@ -547,7 +579,9 @@ impl Named, Handler, HandleClient, - RefClient + RefClient, + Bind, + Unbind )] pub enum LogForwardMessage { /// Receive the log from the parent process and forward ti to the client. @@ -557,18 +591,6 @@ pub enum LogForwardMessage { SetMode { stream_to_client: bool }, } -impl Bind for LogForwardMessage { - fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> { - Ok(()) - } -} - -impl Unbind for LogForwardMessage { - fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> { - Ok(()) - } -} - /// A log forwarder that receives the log from its parent process and forward it back to the client #[derive(Debug)] #[hyperactor::export( @@ -636,17 +658,32 @@ impl Actor for LogForwardActor { #[hyperactor::forward(LogForwardMessage)] impl LogForwardMessageHandler for LogForwardActor { async fn forward(&mut self, ctx: &Context) -> Result<(), anyhow::Error> { - if let Ok(LogMessage::Log { - hostname, - pid, - output_target, - payload, - }) = self.rx.recv().await - { - if self.stream_to_client { - self.logging_client_ref - .log(ctx, hostname, pid, output_target, payload) - .await?; + match self.rx.recv().await { + Ok(LogMessage::Flush { sync_version }) => { + match sync_version { + None => { + // no need to do anything. The previous messages have already been sent to the client. + // Client will flush based on its own frequency. + } + version => { + self.logging_client_ref.flush(ctx, version).await?; + } + } + } + Ok(LogMessage::Log { + hostname, + pid, + output_target, + payload, + }) => { + if self.stream_to_client { + self.logging_client_ref + .log(ctx, hostname, pid, output_target, payload) + .await?; + } + } + Err(e) => { + return Err(e.into()); } } @@ -685,6 +722,60 @@ fn deserialize_message_lines( anyhow::bail!("Failed to deserialize message as either String or Vec") } +/// An actor that send flush message to the log forwarder actor. +/// The reason we need an extra actor instead of reusing the log forwarder actor +/// is because the log forwarder can be blocked on the rx.recv() that listens on the new log lines. +/// Thus, we need to create anew channel as a tx to send the flush message to the log forwarder +/// So we do not get into a deadlock. +#[derive(Debug)] +#[hyperactor::export( + spawn = true, + handlers = [LogFlushMessage {cast = true}], +)] +pub struct LogFlushActor { + tx: ChannelTx, +} + +#[async_trait] +impl Actor for LogFlushActor { + type Params = (); + + async fn new(_: ()) -> Result { + let log_channel: ChannelAddr = match std::env::var(BOOTSTRAP_LOG_CHANNEL) { + Ok(channel) => channel.parse()?, + Err(err) => { + tracing::debug!( + "log forwarder actor failed to read env var {}: {}", + BOOTSTRAP_LOG_CHANNEL, + err + ); + // TODO: this should error out; it can only happen with local proc; we need to fix it. + ChannelAddr::any(ChannelTransport::Unix) + } + }; + let tx = channel::dial::(log_channel)?; + + Ok(Self { tx }) + } +} + +#[async_trait] +#[hyperactor::forward(LogFlushMessage)] +impl LogFlushMessageHandler for LogFlushActor { + async fn force_sync_flush( + &mut self, + _cx: &Context, + version: u64, + ) -> Result<(), anyhow::Error> { + self.tx + .send(LogMessage::Flush { + sync_version: Some(version), + }) + .await + .map_err(anyhow::Error::from) + } +} + /// A client to receive logs from remote processes #[derive(Debug)] #[hyperactor::export( @@ -696,6 +787,11 @@ pub struct LogClientActor { aggregators: HashMap, last_flush_time: SystemTime, next_flush_deadline: Option, + + // For flush sync barrier + current_flush_version: u64, + current_flush_port: Option>, + current_unflushed_procs: usize, } impl LogClientActor { @@ -725,6 +821,12 @@ impl LogClientActor { OutputTarget::Stderr => eprintln!("{}", message), } } + + fn flush_internal(&mut self) { + self.print_aggregators(); + self.last_flush_time = RealClock.system_time_now(); + self.next_flush_deadline = None; + } } #[async_trait] @@ -743,6 +845,9 @@ impl Actor for LogClientActor { aggregators, last_flush_time: RealClock.system_time_now(), next_flush_deadline: None, + current_flush_version: 0, + current_flush_port: None, + current_unflushed_procs: 0, }) } } @@ -794,20 +899,26 @@ impl LogMessageHandler for LogClientActor { let new_deadline = self.last_flush_time + Duration::from_secs(window); let now = RealClock.system_time_now(); if new_deadline <= now { - self.flush(cx).await?; + self.flush_internal(); } else { let delay = new_deadline.duration_since(now)?; match self.next_flush_deadline { None => { self.next_flush_deadline = Some(new_deadline); - cx.self_message_with_delay(LogMessage::Flush {}, delay)?; + cx.self_message_with_delay( + LogMessage::Flush { sync_version: None }, + delay, + )?; } Some(deadline) => { // Some early log lines have alrady triggered the flush. if new_deadline < deadline { // This can happen if the user has adjusted the aggregation window. self.next_flush_deadline = Some(new_deadline); - cx.self_message_with_delay(LogMessage::Flush {}, delay)?; + cx.self_message_with_delay( + LogMessage::Flush { sync_version: None }, + delay, + )?; } } } @@ -818,10 +929,45 @@ impl LogMessageHandler for LogClientActor { Ok(()) } - async fn flush(&mut self, _cx: &Context) -> Result<(), anyhow::Error> { - self.print_aggregators(); - self.last_flush_time = RealClock.system_time_now(); - self.next_flush_deadline = None; + async fn flush( + &mut self, + cx: &Context, + sync_version: Option, + ) -> Result<(), anyhow::Error> { + match sync_version { + None => { + self.flush_internal(); + } + Some(version) => { + if version != self.current_flush_version { + tracing::error!( + "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully", + version, + self.current_flush_version + ); + return Ok(()); + } + + if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() { + // This is a serious issue; it's better to error out. + anyhow::bail!("found no ongoing flush request"); + } + self.current_unflushed_procs -= 1; + + tracing::debug!( + "ack sync flush: version {}; remaining procs: {}", + self.current_flush_version, + self.current_unflushed_procs + ); + + if self.current_unflushed_procs == 0 { + self.flush_internal(); + let reply = self.current_flush_port.take().unwrap(); + self.current_flush_port = None; + reply.send(cx, ()).map_err(anyhow::Error::from)?; + } + } + } Ok(()) } @@ -842,6 +988,34 @@ impl LogClientMessageHandler for LogClientActor { self.aggregate_window_sec = aggregate_window_sec; Ok(()) } + + async fn start_sync_flush( + &mut self, + cx: &Context, + expected_procs_flushed: usize, + reply: OncePortRef<()>, + version: OncePortRef, + ) -> Result<(), anyhow::Error> { + if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() { + tracing::warn!( + "found unfinished ongoing flush: version {}; {} unflushed procs", + self.current_flush_version, + self.current_unflushed_procs, + ); + } + + self.current_flush_version += 1; + tracing::debug!( + "start sync flush with version {}", + self.current_flush_version + ); + self.current_flush_port = Some(reply.clone()); + self.current_unflushed_procs = expected_procs_flushed; + version + .send(cx, self.current_flush_version) + .map_err(anyhow::Error::from)?; + Ok(()) + } } #[cfg(test)] diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 9ff8b208b..a155471e0 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -13,6 +13,8 @@ use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::logging::LogClientActor; use hyperactor_mesh::logging::LogClientMessage; +use hyperactor_mesh::logging::LogFlushActor; +use hyperactor_mesh::logging::LogFlushMessage; use hyperactor_mesh::logging::LogForwardActor; use hyperactor_mesh::logging::LogForwardMessage; use hyperactor_mesh::selection::Selection; @@ -33,11 +35,49 @@ use pyo3::types::PyModule; pub struct LoggingMeshClient { // handles remote process log forwarding; no python runtime forwarder_mesh: SharedCell>, + // because forwarder mesh keeps listening to the new coming logs, + // the flush mesh is a way to unblock it from busy waiting a log and do sync flush. + flush_mesh: SharedCell>, // handles python logger; has python runtime logger_mesh: SharedCell>, client_actor: ActorHandle, } +impl LoggingMeshClient { + async fn flush_internal( + client_actor: ActorHandle, + flush_mesh: SharedCell>, + ) -> Result<(), anyhow::Error> { + let flush_inner_mesh = flush_mesh.borrow().map_err(anyhow::Error::msg)?; + let (reply_tx, reply_rx) = flush_inner_mesh.proc_mesh().client().open_once_port::<()>(); + let (version_tx, version_rx) = flush_inner_mesh + .proc_mesh() + .client() + .open_once_port::(); + + // First initialize a sync flush. + client_actor.send(LogClientMessage::StartSyncFlush { + expected_procs: flush_inner_mesh.proc_mesh().shape().slice().len(), + reply: reply_tx.bind(), + version: version_tx.bind(), + })?; + + let version = version_rx.recv().await?; + + // Then ask all the flushers to ask the log forwarders to sync flush + flush_inner_mesh.cast( + flush_inner_mesh.proc_mesh().client(), + Selection::True, + LogFlushMessage::ForceSyncFlush { version }, + )?; + + // Finally the forwarder will send sync point back to the client, flush, and return. + reply_rx.recv().await?; + + Ok(()) + } +} + #[pymethods] impl LoggingMeshClient { #[staticmethod] @@ -47,9 +87,11 @@ impl LoggingMeshClient { let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?; let client_actor_ref = client_actor.bind(); let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?; + let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?; let logger_mesh = proc_mesh.spawn("logger", &()).await?; Ok(Self { forwarder_mesh, + flush_mesh, logger_mesh, client_actor, }) @@ -97,6 +139,18 @@ impl LoggingMeshClient { Ok(()) } + + // A sync flush mechanism for the client make sure all the stdout/stderr are streamed back and flushed. + fn flush(&self) -> PyResult { + let flush_mesh = self.flush_mesh.clone(); + let client_actor = self.client_actor.clone(); + + PyPythonTask::new(async move { + Self::flush_internal(client_actor, flush_mesh) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + } } impl Drop for LoggingMeshClient { diff --git a/python/monarch/_rust_bindings/monarch_extension/logging.pyi b/python/monarch/_rust_bindings/monarch_extension/logging.pyi index 5d6f11960..fa3d732af 100644 --- a/python/monarch/_rust_bindings/monarch_extension/logging.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/logging.pyi @@ -21,3 +21,4 @@ class LoggingMeshClient: def set_mode( self, stream_to_client: bool, aggregate_window_sec: int | None, level: int ) -> None: ... + def flush(self) -> PythonTask[None]: ... diff --git a/python/tests/python_actor_test_binary.py b/python/tests/python_actor_test_binary.py index 12a10b0f5..9cff72087 100644 --- a/python/tests/python_actor_test_binary.py +++ b/python/tests/python_actor_test_binary.py @@ -10,6 +10,7 @@ import logging import click +from monarch._src.actor.future import Future from monarch.actor import Actor, endpoint, proc_mesh @@ -40,8 +41,10 @@ async def _flush_logs() -> None: for _ in range(5): await am.print.call("has print streaming") - # TODO: will soon be removed by D80051803 - await asyncio.sleep(2) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() @main.command("flush-logs") diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 58f4f16dc..c81ad126f 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -27,6 +27,7 @@ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port +from monarch._src.actor.future import Future from monarch.actor import ( Accumulator, @@ -548,8 +549,10 @@ async def test_actor_log_streaming() -> None: await am.print.call("has print streaming too") await am.log.call("has log streaming as level matched") - # Give it some time to reflect and aggregate - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -664,7 +667,11 @@ async def test_logging_option_defaults() -> None: for _ in range(5): await am.print.call("print streaming") await am.log.call("log streaming") - await asyncio.sleep(4) + + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -790,8 +797,10 @@ async def test_flush_on_disable_aggregation() -> None: for _ in range(5): await am.print.call("single log line") - # Wait a bit to ensure flush completes - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -835,6 +844,32 @@ async def test_flush_on_disable_aggregation() -> None: pass +@pytest.mark.timeout(120) +async def test_multiple_ongoing_flushes_no_deadlock() -> None: + """ + The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked. + Because now a flush call is purely sync, it is very easy to get into a deadlock. + So we assert the last flush call will not get into such a state. + """ + pm = await proc_mesh(gpus=4) + am = await pm.spawn("printer", Printer) + + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(10): + await am.print.call("aggregated log line") + + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + futures = [] + for _ in range(5): + # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature. + await asyncio.sleep(0.1) + futures.append(Future(coro=log_mesh.flush().spawn().task())) + + # The last flush should not block + futures[-1].get() + + @pytest.mark.timeout(60) async def test_adjust_aggregation_window() -> None: """Test that the flush deadline is updated when the aggregation window is adjusted. @@ -875,8 +910,10 @@ async def test_adjust_aggregation_window() -> None: for _ in range(3): await am.print.call("second batch of logs") - # Wait just enough time for the shorter window to trigger a flush - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush()