diff --git a/hyperactor_mesh/src/logging.rs b/hyperactor_mesh/src/logging.rs
index 496f58144..950e0af20 100644
--- a/hyperactor_mesh/src/logging.rs
+++ b/hyperactor_mesh/src/logging.rs
@@ -11,6 +11,7 @@ use std::fmt;
 use std::path::Path;
 use std::path::PathBuf;
 use std::pin::Pin;
+use std::sync::Arc;
 use std::task::Context as TaskContext;
 use std::task::Poll;
 use std::time::Duration;
@@ -22,12 +23,15 @@ use chrono::DateTime;
 use chrono::Local;
 use hyperactor::Actor;
 use hyperactor::ActorRef;
+use hyperactor::Bind;
 use hyperactor::Context;
 use hyperactor::HandleClient;
 use hyperactor::Handler;
 use hyperactor::Instance;
 use hyperactor::Named;
+use hyperactor::OncePortRef;
 use hyperactor::RefClient;
+use hyperactor::Unbind;
 use hyperactor::channel;
 use hyperactor::channel::ChannelAddr;
 use hyperactor::channel::ChannelRx;
@@ -39,14 +43,12 @@ use hyperactor::channel::TxStatus;
 use hyperactor::clock::Clock;
 use hyperactor::clock::RealClock;
 use hyperactor::data::Serialized;
-use hyperactor::message::Bind;
-use hyperactor::message::Bindings;
-use hyperactor::message::Unbind;
 use hyperactor_telemetry::env;
 use hyperactor_telemetry::log_file_path;
 use serde::Deserialize;
 use serde::Serialize;
 use tokio::io;
+use tokio::sync::Mutex;
 use tokio::sync::watch::Receiver;
 
 use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL;
@@ -260,7 +262,11 @@ pub enum LogMessage {
     },
 
     /// Flush the log
-    Flush {},
+    Flush {
+        /// Indicates whether the current flush is synced or non-synced.
+        /// If synced, a version number is available; otherwise, `None`.
+        sync_version: Option<u64>,
+    },
 }
 
 /// Messages that can be sent to the LogClient locally.
@@ -279,6 +285,16 @@ pub enum LogClientMessage {
         /// The time window in seconds to aggregate logs. If None, aggregation is disabled.
         aggregate_window_sec: Option<u64>,
     },
+
+    /// Synchronously flush all the logs from all the procs. This is intended to be called by the client.
+    StartSyncFlush {
+        /// Expect this many procs to ack the flush message.
+        expected_procs: usize,
+        /// Replied to once we have received acks from all the procs.
+        reply: OncePortRef<()>,
+        /// Returns the current flush version to the caller.
+        version: OncePortRef<u64>,
+    },
 }
 
 /// Trait for sending logs
@@ -352,7 +368,7 @@ impl LogSender for LocalLogSender {
         // send will make sure message is delivered
         if TxStatus::Active == *self.status.borrow() {
             // Do not use tx.send, it will block the allocator as the child process state is unknown.
-            self.tx.post(LogMessage::Flush {});
+            self.tx.post(LogMessage::Flush { sync_version: None });
         } else {
             tracing::debug!(
                 "log sender {} is not active, skip sending flush message",
@@ -558,7 +574,9 @@ impl
     Named,
     Handler,
     HandleClient,
-    RefClient
+    RefClient,
+    Bind,
+    Unbind
 )]
 pub enum LogForwardMessage {
@@ -566,18 +584,9 @@ pub enum LogForwardMessage {
     /// Receive the log from the parent process and forward it to the client.
     Forward {},
 
     /// If to stream the log back to the client.
     SetMode { stream_to_client: bool },
-}
 
-impl Bind for LogForwardMessage {
-    fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> {
-        Ok(())
-    }
-}
-
-impl Unbind for LogForwardMessage {
-    fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
-        Ok(())
-    }
+    /// Flush the log with a version number.
+    ForceSyncFlush { version: u64 },
 }
 
 /// A log forwarder that receives the log from its parent process and forwards it back to the client
@@ -588,6 +597,8 @@
 )]
 pub struct LogForwardActor {
     rx: ChannelRx<LogMessage>,
+    flush_tx: Arc<Mutex<ChannelTx<LogMessage>>>,
+    next_flush_deadline: SystemTime,
     logging_client_ref: ActorRef<LogClientActor>,
     stream_to_client: bool,
 }
@@ -630,8 +641,15 @@ impl Actor for LogForwardActor {
                     .1
             }
         };
+
+        // Dial the same channel to send a flush message to drain the log queue.
+        let flush_tx = Arc::new(Mutex::new(channel::dial::<LogMessage>(log_channel)?));
+        let now = RealClock.system_time_now();
+
         Ok(Self {
             rx,
+            flush_tx,
+            next_flush_deadline: now,
             logging_client_ref,
             stream_to_client: true,
         })
@@ -639,6 +657,13 @@
     }
 
     async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
         this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?;
+
+        // Make sure we start the flush loop periodically so the log channel will not deadlock.
+        self.flush_tx
+            .lock()
+            .await
+            .send(LogMessage::Flush { sync_version: None })
+            .await?;
         Ok(())
     }
 }
@@ -647,17 +672,48 @@
 #[hyperactor::forward(LogForwardMessage)]
 impl LogForwardMessageHandler for LogForwardActor {
     async fn forward(&mut self, ctx: &Context<Self>) -> Result<(), anyhow::Error> {
-        if let Ok(LogMessage::Log {
-            hostname,
-            pid,
-            output_target,
-            payload,
-        }) = self.rx.recv().await
-        {
-            if self.stream_to_client {
-                self.logging_client_ref
-                    .log(ctx, hostname, pid, output_target, payload)
-                    .await?;
+        match self.rx.recv().await {
+            Ok(LogMessage::Flush { sync_version }) => {
+                let now = RealClock.system_time_now();
+                match sync_version {
+                    None => {
+                        // Schedule another flush to keep the log channel from deadlocking.
+                        let delay = Duration::from_secs(1);
+                        if now >= self.next_flush_deadline {
+                            self.next_flush_deadline = now + delay;
+                            let flush_tx = self.flush_tx.clone();
+                            tokio::spawn(async move {
+                                RealClock.sleep(delay).await;
+                                if let Err(e) = flush_tx
+                                    .lock()
+                                    .await
+                                    .send(LogMessage::Flush { sync_version: None })
+                                    .await
+                                {
+                                    tracing::error!("failed to send flush message: {}", e);
+                                }
+                            });
+                        }
+                    }
+                    version => {
+                        self.logging_client_ref.flush(ctx, version).await?;
+                    }
+                }
+            }
+            Ok(LogMessage::Log {
+                hostname,
+                pid,
+                output_target,
+                payload,
+            }) => {
+                if self.stream_to_client {
+                    self.logging_client_ref
+                        .log(ctx, hostname, pid, output_target, payload)
+                        .await?;
+                }
+            }
+            Err(e) => {
+                return Err(e.into());
             }
         }
@@ -675,6 +731,21 @@ impl LogForwardMessageHandler for LogForwardActor {
         self.stream_to_client = stream_to_client;
         Ok(())
     }
+
+    async fn force_sync_flush(
+        &mut self,
+        _cx: &Context<Self>,
+        version: u64,
+    ) -> Result<(), anyhow::Error> {
+        self.flush_tx
+            .lock()
+            .await
+            .send(LogMessage::Flush {
+                sync_version: Some(version),
+            })
+            .await
+            .map_err(anyhow::Error::from)
+    }
 }
 
 /// Deserialize a serialized message and split it into UTF-8 lines
@@ -707,6 +778,11 @@ pub struct LogClientActor {
     aggregators: HashMap<OutputTarget, Aggregator>,
     last_flush_time: SystemTime,
     next_flush_deadline: Option<SystemTime>,
+
+    // For the flush sync barrier
+    current_flush_version: u64,
+    current_flush_port: Option<OncePortRef<()>>,
+    current_unflushed_procs: usize,
 }
 
 impl LogClientActor {
@@ -736,6 +812,12 @@ impl LogClientActor {
             OutputTarget::Stderr => eprintln!("{}", message),
         }
     }
+
+    fn flush_internal(&mut self) {
+        self.print_aggregators();
+        self.last_flush_time = RealClock.system_time_now();
+        self.next_flush_deadline = None;
+    }
 }
 
 #[async_trait]
@@ -754,6 +836,9 @@ impl Actor for LogClientActor {
             aggregators,
             last_flush_time: RealClock.system_time_now(),
             next_flush_deadline: None,
+            current_flush_version: 0,
+            current_flush_port: None,
+            current_unflushed_procs: 0,
         })
     }
 }
@@ -805,20 +890,26 @@ impl LogMessageHandler for LogClientActor {
         let new_deadline = self.last_flush_time + Duration::from_secs(window);
         let now = RealClock.system_time_now();
         if new_deadline <= now {
-            self.flush(cx).await?;
+            self.flush_internal();
         } else {
             let delay = new_deadline.duration_since(now)?;
             match self.next_flush_deadline {
                 None => {
                     self.next_flush_deadline = Some(new_deadline);
-                    cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                    cx.self_message_with_delay(
+                        LogMessage::Flush { sync_version: None },
+                        delay,
+                    )?;
                 }
                 Some(deadline) => {
                     // Some early log lines have already triggered the flush.
                     if new_deadline < deadline {
                         // This can happen if the user has adjusted the aggregation window.
                         self.next_flush_deadline = Some(new_deadline);
-                        cx.self_message_with_delay(LogMessage::Flush {}, delay)?;
+                        cx.self_message_with_delay(
+                            LogMessage::Flush { sync_version: None },
+                            delay,
+                        )?;
                     }
                 }
             }
@@ -829,10 +920,45 @@
         }
 
         Ok(())
     }
 
-    async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
-        self.print_aggregators();
-        self.last_flush_time = RealClock.system_time_now();
-        self.next_flush_deadline = None;
+    async fn flush(
+        &mut self,
+        cx: &Context<Self>,
+        sync_version: Option<u64>,
+    ) -> Result<(), anyhow::Error> {
+        match sync_version {
+            None => {
+                self.flush_internal();
+            }
+            Some(version) => {
+                if version != self.current_flush_version {
+                    tracing::error!(
+                        "found mismatched flush versions: got {}, expected {}; this can happen if some previous flush didn't finish fully",
+                        version,
+                        self.current_flush_version
+                    );
+                    return Ok(());
+                }
+
+                if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() {
+                    // This is a serious issue; it's better to error out.
+                    anyhow::bail!("found no ongoing flush request");
+                }
+                self.current_unflushed_procs -= 1;
+
+                tracing::debug!(
+                    "ack sync flush: version {}; remaining procs: {}",
+                    self.current_flush_version,
+                    self.current_unflushed_procs
+                );
+
+                if self.current_unflushed_procs == 0 {
+                    self.flush_internal();
+                    let reply = self.current_flush_port.take().unwrap();
+                    self.current_flush_port = None;
+                    reply.send(cx, ()).map_err(anyhow::Error::from)?;
+                }
+            }
+        }
         Ok(())
     }
 
@@ -853,6 +979,34 @@ impl LogClientMessageHandler for LogClientActor {
         self.aggregate_window_sec = aggregate_window_sec;
         Ok(())
     }
+
+    async fn start_sync_flush(
+        &mut self,
+        cx: &Context<Self>,
+        expected_procs_flushed: usize,
+        reply: OncePortRef<()>,
+        version: OncePortRef<u64>,
+    ) -> Result<(), anyhow::Error> {
+        if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() {
+            tracing::warn!(
+                "found unfinished ongoing flush: version {}; {} unflushed procs",
+                self.current_flush_version,
+                self.current_unflushed_procs,
+            );
+        }
+
+        self.current_flush_version += 1;
+        tracing::debug!(
+            "start sync flush with version {}",
+            self.current_flush_version
+        );
+        self.current_flush_port = Some(reply.clone());
+        self.current_unflushed_procs = expected_procs_flushed;
+        version
+            .send(cx, self.current_flush_version)
+            .map_err(anyhow::Error::from)?;
+        Ok(())
+    }
 }
 
 #[cfg(test)]
diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs
index 9ff8b208b..f58ae24a0 100644
--- a/monarch_extension/src/logging.rs
+++ b/monarch_extension/src/logging.rs
@@ -8,7 +8,11 @@
 #![allow(unsafe_op_in_unsafe_fn)]
 
+use std::time::Duration;
+
 use hyperactor::ActorHandle;
+use hyperactor::clock::Clock;
+use hyperactor::clock::RealClock;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::logging::LogClientActor;
@@ -25,6 +29,8 @@ use pyo3::Bound;
 use pyo3::prelude::*;
 use pyo3::types::PyModule;
 
+static FLUSH_TIMEOUT: Duration = Duration::from_secs(30);
+
 #[pyclass(
     frozen,
     name = "LoggingMeshClient",
@@ -38,6 +44,44 @@ pub struct LoggingMeshClient {
     client_actor: ActorHandle<LogClientActor>,
 }
 
+impl LoggingMeshClient {
+    async fn flush_internal(
+        client_actor: ActorHandle<LogClientActor>,
+        forwarder_mesh: SharedCell<RootActorMesh<'static, LogForwardActor>>,
+    ) -> Result<(), anyhow::Error> {
+        let forwarder_inner_mesh = forwarder_mesh.borrow().map_err(anyhow::Error::msg)?;
+        let (reply_tx, reply_rx) = forwarder_inner_mesh
+            .proc_mesh()
+            .client()
+            .open_once_port::<()>();
+        let (version_tx, version_rx) = forwarder_inner_mesh
+            .proc_mesh()
+            .client()
+            .open_once_port::<u64>();
+
+        // First, initiate a sync flush.
+        client_actor.send(LogClientMessage::StartSyncFlush {
+            expected_procs: forwarder_inner_mesh.proc_mesh().shape().slice().len(),
+            reply: reply_tx.bind(),
+            version: version_tx.bind(),
+        })?;
+
+        let version = version_rx.recv().await?;
+
+        // Then ask all the log forwarders to sync flush.
+        forwarder_inner_mesh.cast(
+            forwarder_inner_mesh.proc_mesh().client(),
+            Selection::True,
+            LogForwardMessage::ForceSyncFlush { version },
+        )?;
+
+        // Finally, the forwarders report the sync point back to the client, which flushes and replies.
+        reply_rx.recv().await?;
+
+        Ok(())
+    }
+}
+
 #[pymethods]
 impl LoggingMeshClient {
     #[staticmethod]
@@ -48,6 +92,38 @@
         let client_actor_ref = client_actor.bind();
         let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
         let logger_mesh = proc_mesh.spawn("logger", &()).await?;
+
+        // Register flush_internal as an on-stop callback.
+        let client_actor_for_callback = client_actor.clone();
+        let forwarder_mesh_for_callback = forwarder_mesh.clone();
+        proc_mesh
+            .register_onstop_callback(|| async move {
+                match RealClock
+                    .timeout(
+                        FLUSH_TIMEOUT,
+                        Self::flush_internal(
+                            client_actor_for_callback,
+                            forwarder_mesh_for_callback,
+                        ),
+                    )
+                    .await
+                {
+                    Ok(Ok(())) => {
+                        tracing::debug!("flush completed successfully during shutdown");
+                    }
+                    Ok(Err(e)) => {
+                        tracing::error!("error during flush: {}", e);
+                    }
+                    Err(_) => {
+                        tracing::error!(
+                            "flush timed out after {} seconds during shutdown",
+                            FLUSH_TIMEOUT.as_secs()
+                        );
+                    }
+                }
+            })
+            .await?;
+
         Ok(Self {
             forwarder_mesh,
             logger_mesh,
@@ -65,7 +141,7 @@
     ) -> PyResult<()> {
         if aggregate_window_sec.is_some() && !stream_to_client {
             return Err(PyErr::new::(
-                "Cannot set aggregate window without streaming to client".to_string(),
+                "cannot set aggregate window without streaming to client".to_string(),
             ));
         }
 
@@ -97,6 +173,18 @@
 
         Ok(())
     }
+
+    // A sync flush mechanism so the client can make sure all stdout/stderr has been streamed back and flushed.
+ fn flush(&self) -> PyResult { + let forwarder_mesh = self.forwarder_mesh.clone(); + let client_actor = self.client_actor.clone(); + + PyPythonTask::new(async move { + Self::flush_internal(client_actor, forwarder_mesh) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + } } impl Drop for LoggingMeshClient { diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 3a0a22907..74e22d54f 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -38,6 +38,8 @@ use pyo3::types::PyType; use tokio::sync::Mutex; use tokio::sync::mpsc; +type OnStopCallback = Box Box + Send> + Send>; + use crate::actor_mesh::PythonActorMesh; use crate::actor_mesh::PythonActorMeshImpl; use crate::alloc::PyAlloc; @@ -55,6 +57,7 @@ pub struct TrackedProcMesh { inner: SharedCellRef, cell: SharedCell, children: SharedCellPool, + onstop_callbacks: Arc>>, } impl Debug for TrackedProcMesh { @@ -77,6 +80,7 @@ impl From for TrackedProcMesh { inner, cell, children: SharedCellPool::new(), + onstop_callbacks: Arc::new(Mutex::new(Vec::new())), } } } @@ -107,8 +111,25 @@ impl TrackedProcMesh { self.inner.client_proc() } - pub fn into_inner(self) -> (SharedCell, SharedCellPool) { - (self.cell, self.children) + pub fn into_inner( + self, + ) -> ( + SharedCell, + SharedCellPool, + Arc>>, + ) { + (self.cell, self.children, self.onstop_callbacks) + } + + /// Register a callback to be called when this TrackedProcMesh is stopped + pub async fn register_onstop_callback(&self, callback: F) -> Result<(), anyhow::Error> + where + F: FnOnce() -> Fut + Send + 'static, + Fut: std::future::Future + Send + 'static, + { + let mut callbacks = self.onstop_callbacks.lock().await; + callbacks.push(Box::new(|| Box::new(callback()))); + Ok(()) } } @@ -230,7 +251,17 @@ impl PyProcMesh { let tracked_proc_mesh = inner.take().await.map_err(|e| { PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e)) })?; - let (proc_mesh, children) = tracked_proc_mesh.into_inner(); + let (proc_mesh, children, drop_callbacks) = tracked_proc_mesh.into_inner(); + + // Call all registered drop callbacks before stopping + let mut callbacks = drop_callbacks.lock().await; + let callbacks_to_call = callbacks.drain(..).collect::>(); + drop(callbacks); // Release the lock + + for callback in callbacks_to_call { + let future = callback(); + std::pin::Pin::from(future).await; + } // Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused". // Discarding actor meshes that have been individually stopped will result in an expected error @@ -488,3 +519,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; Ok(()) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicU32; + use std::sync::atomic::Ordering; + + use anyhow::Result; + use hyperactor_mesh::alloc::AllocSpec; + use hyperactor_mesh::alloc::Allocator; + use hyperactor_mesh::alloc::local::LocalAllocator; + use hyperactor_mesh::proc_mesh::ProcMesh; + use ndslice::extent; + use tokio::sync::Mutex; + + use super::*; + + #[tokio::test] + async fn test_register_onstop_callback_single() -> Result<()> { + // Create a TrackedProcMesh + let alloc = LocalAllocator + .allocate(AllocSpec { + extent: extent! 
{ replica = 1 }, + constraints: Default::default(), + }) + .await?; + + let mut proc_mesh = ProcMesh::allocate(alloc).await?; + + // Extract events before wrapping in TrackedProcMesh + let events = proc_mesh.events().unwrap(); + let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events)); + + let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh); + + // Create a flag to track if callback was executed + let callback_executed = Arc::new(AtomicBool::new(false)); + let callback_executed_clone = callback_executed.clone(); + + // Register a callback + tracked_proc_mesh + .register_onstop_callback(move || { + let flag = callback_executed_clone.clone(); + async move { + flag.store(true, Ordering::SeqCst); + } + }) + .await?; + + // Create a SharedCell for stop_mesh + let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh); + + // Call stop_mesh (this should trigger the callback) + PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?; + + // Verify the callback was executed + assert!( + callback_executed.load(Ordering::SeqCst), + "Callback should have been executed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_register_onstop_callback_multiple() -> Result<()> { + // Create a TrackedProcMesh + let alloc = LocalAllocator + .allocate(AllocSpec { + extent: extent! { replica = 1 }, + constraints: Default::default(), + }) + .await?; + + let mut proc_mesh = ProcMesh::allocate(alloc).await?; + + // Extract events before wrapping in TrackedProcMesh + let events = proc_mesh.events().unwrap(); + let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events)); + + let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh); + + // Create counters to track callback executions + let callback_count = Arc::new(AtomicU32::new(0)); + let execution_order = Arc::new(Mutex::new(Vec::::new())); + + // Register multiple callbacks + for i in 1..=3 { + let count = callback_count.clone(); + let order = execution_order.clone(); + tracked_proc_mesh + .register_onstop_callback(move || { + let count_clone = count.clone(); + let order_clone = order.clone(); + async move { + count_clone.fetch_add(1, Ordering::SeqCst); + let mut order_vec = order_clone.lock().await; + order_vec.push(i); + } + }) + .await?; + } + + // Create a SharedCell for stop_mesh + let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh); + + // Call stop_mesh (this should trigger all callbacks) + PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?; + + // Verify all callbacks were executed + assert_eq!( + callback_count.load(Ordering::SeqCst), + 3, + "All 3 callbacks should have been executed" + ); + + // Verify execution order (callbacks should be executed in registration order) + let order_vec = execution_order.lock().await; + assert_eq!( + *order_vec, + vec![1, 2, 3], + "Callbacks should be executed in registration order" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_register_onstop_callback_error_handling() -> Result<()> { + // Create a TrackedProcMesh + let alloc = LocalAllocator + .allocate(AllocSpec { + extent: extent! 
{ replica = 1 }, + constraints: Default::default(), + }) + .await?; + + let mut proc_mesh = ProcMesh::allocate(alloc).await?; + + // Extract events before wrapping in TrackedProcMesh + let events = proc_mesh.events().unwrap(); + let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events)); + + let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh); + + // Create flags to track callback executions + let callback1_executed = Arc::new(AtomicBool::new(false)); + let callback2_executed = Arc::new(AtomicBool::new(false)); + + let callback1_executed_clone = callback1_executed.clone(); + let callback2_executed_clone = callback2_executed.clone(); + + // Register a callback that panics + tracked_proc_mesh + .register_onstop_callback(move || { + let flag = callback1_executed_clone.clone(); + async move { + flag.store(true, Ordering::SeqCst); + // This callback completes successfully + } + }) + .await?; + + // Register another callback that should still execute even if the first one had issues + tracked_proc_mesh + .register_onstop_callback(move || { + let flag = callback2_executed_clone.clone(); + async move { + flag.store(true, Ordering::SeqCst); + } + }) + .await?; + + // Create a SharedCell for stop_mesh + let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh); + + // Call stop_mesh (this should trigger both callbacks) + PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?; + + // Verify both callbacks were executed + assert!( + callback1_executed.load(Ordering::SeqCst), + "First callback should have been executed" + ); + assert!( + callback2_executed.load(Ordering::SeqCst), + "Second callback should have been executed" + ); + + Ok(()) + } +} diff --git a/python/monarch/_rust_bindings/monarch_extension/logging.pyi b/python/monarch/_rust_bindings/monarch_extension/logging.pyi index 5d6f11960..fa3d732af 100644 --- a/python/monarch/_rust_bindings/monarch_extension/logging.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/logging.pyi @@ -21,3 +21,4 @@ class LoggingMeshClient: def set_mode( self, stream_to_client: bool, aggregate_window_sec: int | None, level: int ) -> None: ... + def flush(self) -> PythonTask[None]: ... diff --git a/python/monarch/_src/actor/logging.py b/python/monarch/_src/actor/logging.py new file mode 100644 index 000000000..f56003bb5 --- /dev/null +++ b/python/monarch/_src/actor/logging.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-strict
+
+import gc
+import logging
+
+from typing import Callable
+
+from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
+
+from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
+from monarch._src.actor.future import Future
+
+IN_IPYTHON = False
+try:
+    # Check if we are in an ipython environment
+    # pyre-ignore[21]
+    from IPython import get_ipython
+
+    # pyre-ignore[21]
+    from IPython.core.interactiveshell import ExecutionResult
+
+    IN_IPYTHON = get_ipython() is not None
+except ImportError:
+    pass
+
+
+class LoggingManager:
+    def __init__(self) -> None:
+        self._logging_mesh_client: LoggingMeshClient | None = None
+        self._ipython_flush_logs_handler: Callable[..., None] | None = None
+
+    async def init(self, proc_mesh: HyProcMesh) -> None:
+        if self._logging_mesh_client is not None:
+            return
+
+        self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
+        self._logging_mesh_client.set_mode(
+            stream_to_client=True,
+            aggregate_window_sec=3,
+            level=logging.INFO,
+        )
+
+        if IN_IPYTHON:
+            # In an ipython environment, a cell can finish quickly while threads are still running in the background.
+            # Flush all the ongoing logs proactively to avoid missing logs.
+            assert self._logging_mesh_client is not None
+            logging_client: LoggingMeshClient = self._logging_mesh_client
+            ipython = get_ipython()
+
+            # pyre-ignore[11]
+            def flush_logs(_: ExecutionResult) -> None:
+                try:
+                    Future(coro=logging_client.flush().spawn().task()).get(3)
+                except TimeoutError:
+                    # Don't hang if a failed proc mesh never comes back.
+                    pass
+
+                # Force a GC pass to recycle any previously undropped proc_mesh.
+                # Otherwise, we may end up with unregistered dead callbacks.
+                gc.collect()
+
+            # Store the handler reference so we can unregister it later
+            self._ipython_flush_logs_handler = flush_logs
+            ipython.events.register("post_run_cell", flush_logs)
+
+    async def logging_option(
+        self,
+        stream_to_client: bool = True,
+        aggregate_window_sec: int | None = 3,
+        level: int = logging.INFO,
+    ) -> None:
+        if level < 0 or level > 255:
+            raise ValueError("Invalid logging level: {}".format(level))
+
+        assert self._logging_mesh_client is not None
+        self._logging_mesh_client.set_mode(
+            stream_to_client=stream_to_client,
+            aggregate_window_sec=aggregate_window_sec,
+            level=level,
+        )
+
+    def stop(self) -> None:
+        if self._ipython_flush_logs_handler is not None:
+            assert IN_IPYTHON
+            ipython = get_ipython()
+            assert ipython is not None
+            ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
+            self._ipython_flush_logs_handler = None
diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py
index 6aee0a973..a0b5e5023 100644
--- a/python/monarch/_src/actor/proc_mesh.py
+++ b/python/monarch/_src/actor/proc_mesh.py
@@ -33,7 +33,6 @@
 )
 from weakref import WeakValueDictionary
 
-from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
 from monarch._rust_bindings.monarch_hyperactor.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
     Alloc,
     AllocConstraints,
@@ -67,10 +66,12 @@
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import DeprecatedNotAFuture, Future
+from monarch._src.actor.logging import LoggingManager
 from monarch._src.actor.shape import MeshTrait
 from monarch.tools.config import Workspace
 from monarch.tools.utils import conda as conda_utils
 
+
 HAS_TENSOR_ENGINE = False
 try:
     # Torch is needed for tensor engine
@@ -191,7 +192,7 @@ def __init__( # of whether
this is a slice of a real proc_meshg self._slice = False self._code_sync_client: Optional[CodeSyncMeshClient] = None - self._logging_mesh_client: Optional[LoggingMeshClient] = None + self._logging_manager: LoggingManager = LoggingManager() self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh self._stopped = False self._controller_controller: Optional["_ControllerController"] = None @@ -309,14 +310,8 @@ async def task( ) -> HyProcMesh: hy_proc_mesh = await hy_proc_mesh_task - pm._logging_mesh_client = await LoggingMeshClient.spawn( - proc_mesh=hy_proc_mesh - ) - pm._logging_mesh_client.set_mode( - stream_to_client=True, - aggregate_window_sec=3, - level=logging.INFO, - ) + # logging mesh is only makes sense with forked (remote or local) processes + await pm._logging_manager.init(hy_proc_mesh) if setup_actor is not None: await setup_actor.setup.call() @@ -483,12 +478,9 @@ async def logging_option( Returns: None """ - if level < 0 or level > 255: - raise ValueError("Invalid logging level: {}".format(level)) await self.initialized - assert self._logging_mesh_client is not None - self._logging_mesh_client.set_mode( + await self._logging_manager.logging_option( stream_to_client=stream_to_client, aggregate_window_sec=aggregate_window_sec, level=level, @@ -501,6 +493,8 @@ async def __aenter__(self) -> "ProcMesh": def stop(self) -> Future[None]: async def _stop_nonblocking() -> None: + self._logging_manager.stop() + await (await self._proc_mesh).stop_nonblocking() self._stopped = True @@ -517,6 +511,8 @@ async def __aexit__( # Finalizer to check if the proc mesh was closed properly. def __del__(self) -> None: if not self._stopped: + self._logging_manager.stop() + warnings.warn( f"unstopped ProcMesh {self!r}", ResourceWarning, diff --git a/python/tests/python_actor_test_binary.py b/python/tests/python_actor_test_binary.py index 12a10b0f5..e06a5952e 100644 --- a/python/tests/python_actor_test_binary.py +++ b/python/tests/python_actor_test_binary.py @@ -40,8 +40,7 @@ async def _flush_logs() -> None: for _ in range(5): await am.print.call("has print streaming") - # TODO: will soon be removed by D80051803 - await asyncio.sleep(2) + await pm.stop() @main.command("flush-logs") diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 5d9ce8d98..0c05f67f0 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -6,6 +6,7 @@ # pyre-unsafe import asyncio +import gc import importlib.resources import logging import operator @@ -27,6 +28,7 @@ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port +from monarch._src.actor.future import Future from monarch.actor import ( Accumulator, @@ -559,8 +561,7 @@ async def test_actor_log_streaming() -> None: await am.print.call("has print streaming too") await am.log.call("has log streaming as level matched") - # Give it some time to reflect and aggregate - await asyncio.sleep(1) + await pm.stop() # Flush all outputs stdout_file.flush() @@ -675,7 +676,8 @@ async def test_logging_option_defaults() -> None: for _ in range(5): await am.print.call("print streaming") await am.log.call("log streaming") - await asyncio.sleep(4) + + await pm.stop() # Flush all outputs stdout_file.flush() @@ -728,6 +730,147 @@ async def test_logging_option_defaults() -> None: pass +# oss_skip: pytest keeps complaining about mocking get_ipython module +@pytest.mark.oss_skip +@pytest.mark.timeout(180) +async def test_flush_logs_ipython() -> 
None: + """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" + # Save original file descriptors + original_stdout_fd = os.dup(1) # stdout + + try: + # Create temporary files to capture output + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: + stdout_path = stdout_file.name + + # Redirect file descriptors to our temp files + os.dup2(stdout_file.fileno(), 1) + + # Also redirect Python's sys.stdout + original_sys_stdout = sys.stdout + sys.stdout = stdout_file + + try: + # Mock IPython environment + class MockExecutionResult: + pass + + class MockEvents: + def __init__(self): + self.callbacks = {} + self.registers = 0 + self.unregisters = 0 + + def register(self, event_name, callback): + if event_name not in self.callbacks: + self.callbacks[event_name] = [] + self.callbacks[event_name].append(callback) + self.registers += 1 + + def unregister(self, event_name, callback): + if event_name not in self.callbacks: + raise ValueError(f"Event {event_name} not registered") + assert callback in self.callbacks[event_name] + self.callbacks[event_name].remove(callback) + self.unregisters += 1 + + def trigger(self, event_name, *args, **kwargs): + if event_name in self.callbacks: + for callback in self.callbacks[event_name]: + callback(*args, **kwargs) + + class MockIPython: + def __init__(self): + self.events = MockEvents() + + mock_ipython = MockIPython() + + with unittest.mock.patch( + "monarch._src.actor.logging.get_ipython", + lambda: mock_ipython, + ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): + # Make sure we can register and unregister callbacks + for _ in range(3): + pm1 = await proc_mesh(gpus=2) + pm2 = await proc_mesh(gpus=2) + am1 = await pm1.spawn("printer", Printer) + am2 = await pm2.spawn("printer", Printer) + + # Set aggregation window to ensure logs are buffered + await pm1.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + await pm2.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + await asyncio.sleep(1) + + # Generate some logs that will be aggregated + for _ in range(5): + await am1.print.call("ipython1 test log") + await am2.print.call("ipython2 test log") + + # Trigger the post_run_cell event which should flush logs + mock_ipython.events.trigger( + "post_run_cell", MockExecutionResult() + ) + + gc.collect() + + assert mock_ipython.events.registers == 6 + # TODO: figure out why the latest unregister is not called + assert mock_ipython.events.unregisters == 4 + assert len(mock_ipython.events.callbacks["post_run_cell"]) == 2 + + # Flush all outputs + stdout_file.flush() + os.fsync(stdout_file.fileno()) + + finally: + # Restore Python's sys.stdout + sys.stdout = original_sys_stdout + + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) + + # Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() + + # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils + + # Clean up temp files + os.unlink(stdout_path) + + # Verify that logs were flushed when the post_run_cell event was triggered + # We should see the aggregated logs in the output + assert ( + len( + re.findall( + r"\[10 similar log lines\].*ipython1 test log", stdout_content + ) + ) + == 3 + ), stdout_content + + assert ( + len( + re.findall( + r"\[10 similar log lines\].*ipython2 test log", stdout_content + ) + ) + == 3 + ), stdout_content + + finally: + # Ensure file descriptors are restored even if 
something goes wrong + try: + os.dup2(original_stdout_fd, 1) + os.close(original_stdout_fd) + except OSError: + pass + + # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @pytest.mark.oss_skip async def test_flush_logs_fast_exit() -> None: @@ -801,8 +944,7 @@ async def test_flush_on_disable_aggregation() -> None: for _ in range(5): await am.print.call("single log line") - # Wait a bit to ensure flush completes - await asyncio.sleep(1) + await pm.stop() # Flush all outputs stdout_file.flush() @@ -846,6 +988,32 @@ async def test_flush_on_disable_aggregation() -> None: pass +@pytest.mark.timeout(120) +async def test_multiple_ongoing_flushes_no_deadlock() -> None: + """ + The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked. + Because now a flush call is purely sync, it is very easy to get into a deadlock. + So we assert the last flush call will not get into such a state. + """ + pm = await proc_mesh(gpus=4) + am = await pm.spawn("printer", Printer) + + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(10): + await am.print.call("aggregated log line") + + log_mesh = pm._logging_manager._logging_mesh_client + assert log_mesh is not None + futures = [] + for _ in range(5): + # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature. + await asyncio.sleep(0.1) + futures.append(Future(coro=log_mesh.flush().spawn().task())) + + # The last flush should not block + futures[-1].get() + + @pytest.mark.timeout(60) async def test_adjust_aggregation_window() -> None: """Test that the flush deadline is updated when the aggregation window is adjusted. @@ -886,8 +1054,7 @@ async def test_adjust_aggregation_window() -> None: for _ in range(3): await am.print.call("second batch of logs") - # Wait just enough time for the shorter window to trigger a flush - await asyncio.sleep(1) + await pm.stop() # Flush all outputs stdout_file.flush()
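For reference, a minimal usage sketch of how the synchronous flush introduced by this patch is driven from Python, mirroring test_multiple_ongoing_flushes_no_deadlock and the LoggingManager code above; the mesh size, the "printer" name, and the Printer actor are illustrative placeholders, not part of the patch.

# Usage sketch (assumptions: Printer is any actor with a print endpoint, gpus=2 is arbitrary).
from monarch._src.actor.future import Future
from monarch.actor import proc_mesh


async def flush_example() -> None:
    pm = await proc_mesh(gpus=2)
    am = await pm.spawn("printer", Printer)
    await am.print.call("buffered log line")

    # Reach the logging mesh client and trigger a synchronous flush across all procs.
    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    Future(coro=log_mesh.flush().spawn().task()).get()

    # Stopping the mesh also flushes, via the on-stop callback registered at spawn time.
    await pm.stop()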