Commit 66985b3

James Sun authored and facebook-github-bot committed

sync flush logs upon mesh stop (#885)

Summary: force sync flush upon mesh stop

Differential Revision: D80310284
1 parent 65ff477 · commit 66985b3

File tree

4 files changed: +271 -25 lines changed

monarch_extension/src/logging.rs

Lines changed: 39 additions & 1 deletion
@@ -8,7 +8,11 @@
 
 #![allow(unsafe_op_in_unsafe_fn)]
 
+use std::time::Duration;
+
 use hyperactor::ActorHandle;
+use hyperactor::clock::Clock;
+use hyperactor::clock::RealClock;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::logging::LogClientActor;
@@ -27,6 +31,8 @@ use pyo3::Bound;
 use pyo3::prelude::*;
 use pyo3::types::PyModule;
 
+static FLUSH_TIMEOUT: Duration = Duration::from_secs(30);
+
 #[pyclass(
     frozen,
     name = "LoggingMeshClient",
@@ -89,6 +95,38 @@ impl LoggingMeshClient {
         let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
         let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?;
         let logger_mesh = proc_mesh.spawn("logger", &()).await?;
+
+        // Register flush_internal as an on-stop callback
+        let client_actor_for_callback = client_actor.clone();
+        let flush_mesh_for_callback = flush_mesh.clone();
+        proc_mesh
+            .register_onstop_callback(|| async move {
+                match RealClock
+                    .timeout(
+                        FLUSH_TIMEOUT,
+                        Self::flush_internal(
+                            client_actor_for_callback,
+                            flush_mesh_for_callback,
+                        ),
+                    )
+                    .await
+                {
+                    Ok(Ok(())) => {
+                        tracing::debug!("flush completed successfully during shutdown");
+                    }
+                    Ok(Err(e)) => {
+                        tracing::error!("error during flush: {}", e);
+                    }
+                    Err(_) => {
+                        tracing::error!(
+                            "flush timed out after {} seconds during shutdown",
+                            FLUSH_TIMEOUT.as_secs()
+                        );
+                    }
+                }
+            })
+            .await?;
+
         Ok(Self {
             forwarder_mesh,
             flush_mesh,
@@ -107,7 +145,7 @@ impl LoggingMeshClient {
     ) -> PyResult<()> {
         if aggregate_window_sec.is_some() && !stream_to_client {
             return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
-                "Cannot set aggregate window without streaming to client".to_string(),
+                "cannot set aggregate window without streaming to client".to_string(),
             ));
         }
 
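The registration above wraps the shutdown flush in `RealClock.timeout`, so a stuck flush cannot hang mesh stop. Below is a minimal standalone sketch of the same shape, assuming plain `tokio::time::timeout` in place of hyperactor's clock; `flush_all` is a hypothetical stand-in for `Self::flush_internal`, not an API from this repo:

```rust
use std::time::Duration;

use tokio::time::timeout;

// Hypothetical stand-in for the real flush routine (Self::flush_internal).
async fn flush_all() -> Result<(), String> {
    // Pretend to drain buffered log lines to their sinks.
    Ok(())
}

const FLUSH_TIMEOUT: Duration = Duration::from_secs(30);

#[tokio::main]
async fn main() {
    // The timeout wraps the flush future, producing a nested Result:
    // the outer layer reports the timeout, the inner one the flush error.
    match timeout(FLUSH_TIMEOUT, flush_all()).await {
        Ok(Ok(())) => println!("flush completed successfully during shutdown"),
        Ok(Err(e)) => eprintln!("error during flush: {}", e),
        Err(_) => eprintln!("flush timed out after {}s", FLUSH_TIMEOUT.as_secs()),
    }
}
```

The nested `Result` is what lets the three outcomes in the diff (clean flush, flush error, timeout) each get their own log line.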

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 227 additions & 3 deletions
@@ -38,6 +38,8 @@ use pyo3::types::PyType;
 use tokio::sync::Mutex;
 use tokio::sync::mpsc;
 
+type OnStopCallback = Box<dyn FnOnce() -> Box<dyn std::future::Future<Output = ()> + Send> + Send>;
+
 use crate::actor_mesh::PythonActorMesh;
 use crate::actor_mesh::PythonActorMeshImpl;
 use crate::alloc::PyAlloc;
@@ -55,6 +57,7 @@ pub struct TrackedProcMesh {
     inner: SharedCellRef<ProcMesh>,
     cell: SharedCell<ProcMesh>,
     children: SharedCellPool,
+    onstop_callbacks: Arc<Mutex<Vec<OnStopCallback>>>,
 }
 
 impl Debug for TrackedProcMesh {
@@ -77,6 +80,7 @@ impl From<ProcMesh> for TrackedProcMesh {
             inner,
             cell,
             children: SharedCellPool::new(),
+            onstop_callbacks: Arc::new(Mutex::new(Vec::new())),
         }
     }
 }
@@ -107,8 +111,25 @@ impl TrackedProcMesh {
         self.inner.client_proc()
     }
 
-    pub fn into_inner(self) -> (SharedCell<ProcMesh>, SharedCellPool) {
-        (self.cell, self.children)
+    pub fn into_inner(
+        self,
+    ) -> (
+        SharedCell<ProcMesh>,
+        SharedCellPool,
+        Arc<Mutex<Vec<OnStopCallback>>>,
+    ) {
+        (self.cell, self.children, self.onstop_callbacks)
+    }
+
+    /// Register a callback to be called when this TrackedProcMesh is stopped
+    pub async fn register_onstop_callback<F, Fut>(&self, callback: F) -> Result<(), anyhow::Error>
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: std::future::Future<Output = ()> + Send + 'static,
+    {
+        let mut callbacks = self.onstop_callbacks.lock().await;
+        callbacks.push(Box::new(|| Box::new(callback())));
+        Ok(())
     }
 }
 
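The `OnStopCallback` alias introduced above erases two types at once: the outer `Box<dyn FnOnce ...>` hides the concrete closure type, and the inner `Box<dyn Future ...>` hides each closure's distinct future type, so heterogeneous callbacks can share one `Vec`. Because the inner box is unpinned, it must be re-pinned before it can be awaited, which is what the stop path's `std::pin::Pin::from(future).await` does (next hunk). A minimal sketch of the same erasure outside hyperactor, with all names illustrative:

```rust
use std::future::Future;
use std::pin::Pin;

// Same shape as the OnStopCallback alias in the diff: a one-shot closure
// yielding a boxed, type-erased future.
type Callback = Box<dyn FnOnce() -> Box<dyn Future<Output = ()> + Send> + Send>;

// Generic front door, mirroring register_onstop_callback's bounds: any
// FnOnce() -> impl Future can be erased into the common Callback type.
fn erase<F, Fut>(f: F) -> Callback
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    Box::new(move || Box::new(f()))
}

#[tokio::main]
async fn main() {
    let mut callbacks: Vec<Callback> = vec![
        erase(|| async { println!("first callback ran") }),
        erase(|| async { println!("second callback ran") }),
    ];

    // Invoke in registration order. Box<dyn Future> is not awaitable by
    // itself; Pin::from converts it into Pin<Box<dyn Future>>, which is.
    for cb in callbacks.drain(..) {
        Pin::from(cb()).await;
    }
}
```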

@@ -230,7 +251,17 @@ impl PyProcMesh {
         let tracked_proc_mesh = inner.take().await.map_err(|e| {
             PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e))
         })?;
-        let (proc_mesh, children) = tracked_proc_mesh.into_inner();
+        let (proc_mesh, children, drop_callbacks) = tracked_proc_mesh.into_inner();
+
+        // Call all registered drop callbacks before stopping
+        let mut callbacks = drop_callbacks.lock().await;
+        let callbacks_to_call = callbacks.drain(..).collect::<Vec<_>>();
+        drop(callbacks); // Release the lock
+
+        for callback in callbacks_to_call {
+            let future = callback();
+            std::pin::Pin::from(future).await;
+        }
 
         // Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
         // Discarding actor meshes that have been individually stopped will result in an expected error
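One detail worth noting in this hunk: the callbacks are drained out of the `Vec` and the guard is explicitly dropped before anything is awaited, so the `tokio::sync::Mutex` is never held across an await point and a concurrent `register_onstop_callback` caller is not blocked behind a slow flush. A small sketch of that lock discipline, under the same `tokio::sync::Mutex` assumption (string "steps" stand in for the boxed callbacks):

```rust
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Mutex;

#[tokio::main]
async fn main() {
    // Illustrative stand-in for the shared callback list.
    let steps: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(vec![
        "flush logs".to_string(),
        "close sockets".to_string(),
    ]));

    // Hold the lock only long enough to move the items out...
    let mut guard = steps.lock().await;
    let to_run: Vec<String> = guard.drain(..).collect();
    drop(guard); // ...and release it before the first await below.

    // Other tasks may lock `steps` again while this slow work proceeds.
    for step in to_run {
        tokio::time::sleep(Duration::from_millis(10)).await;
        println!("ran shutdown step: {}", step);
    }
}
```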
@@ -488,3 +519,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
     hyperactor_mod.add_class::<PyProcEvent>()?;
     Ok(())
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+    use std::sync::atomic::AtomicBool;
+    use std::sync::atomic::AtomicU32;
+    use std::sync::atomic::Ordering;
+
+    use anyhow::Result;
+    use hyperactor_mesh::alloc::AllocSpec;
+    use hyperactor_mesh::alloc::Allocator;
+    use hyperactor_mesh::alloc::local::LocalAllocator;
+    use hyperactor_mesh::proc_mesh::ProcMesh;
+    use ndslice::extent;
+    use tokio::sync::Mutex;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_register_onstop_callback_single() -> Result<()> {
+        // Create a TrackedProcMesh
+        let alloc = LocalAllocator
+            .allocate(AllocSpec {
+                extent: extent! { replica = 1 },
+                constraints: Default::default(),
+            })
+            .await?;
+
+        let mut proc_mesh = ProcMesh::allocate(alloc).await?;
+
+        // Extract events before wrapping in TrackedProcMesh
+        let events = proc_mesh.events().unwrap();
+        let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
+
+        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
+
+        // Create a flag to track if callback was executed
+        let callback_executed = Arc::new(AtomicBool::new(false));
+        let callback_executed_clone = callback_executed.clone();
+
+        // Register a callback
+        tracked_proc_mesh
+            .register_onstop_callback(move || {
+                let flag = callback_executed_clone.clone();
+                async move {
+                    flag.store(true, Ordering::SeqCst);
+                }
+            })
+            .await?;
+
+        // Create a SharedCell<TrackedProcMesh> for stop_mesh
+        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
+
+        // Call stop_mesh (this should trigger the callback)
+        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
+
+        // Verify the callback was executed
+        assert!(
+            callback_executed.load(Ordering::SeqCst),
+            "Callback should have been executed"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_register_onstop_callback_multiple() -> Result<()> {
+        // Create a TrackedProcMesh
+        let alloc = LocalAllocator
+            .allocate(AllocSpec {
+                extent: extent! { replica = 1 },
+                constraints: Default::default(),
+            })
+            .await?;
+
+        let mut proc_mesh = ProcMesh::allocate(alloc).await?;
+
+        // Extract events before wrapping in TrackedProcMesh
+        let events = proc_mesh.events().unwrap();
+        let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
+
+        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
+
+        // Create counters to track callback executions
+        let callback_count = Arc::new(AtomicU32::new(0));
+        let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));
+
+        // Register multiple callbacks
+        for i in 1..=3 {
+            let count = callback_count.clone();
+            let order = execution_order.clone();
+            tracked_proc_mesh
+                .register_onstop_callback(move || {
+                    let count_clone = count.clone();
+                    let order_clone = order.clone();
+                    async move {
+                        count_clone.fetch_add(1, Ordering::SeqCst);
+                        let mut order_vec = order_clone.lock().await;
+                        order_vec.push(i);
+                    }
+                })
+                .await?;
+        }
+
+        // Create a SharedCell<TrackedProcMesh> for stop_mesh
+        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
+
+        // Call stop_mesh (this should trigger all callbacks)
+        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
+
+        // Verify all callbacks were executed
+        assert_eq!(
+            callback_count.load(Ordering::SeqCst),
+            3,
+            "All 3 callbacks should have been executed"
+        );
+
+        // Verify execution order (callbacks should be executed in registration order)
+        let order_vec = execution_order.lock().await;
+        assert_eq!(
+            *order_vec,
+            vec![1, 2, 3],
+            "Callbacks should be executed in registration order"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_register_onstop_callback_error_handling() -> Result<()> {
+        // Create a TrackedProcMesh
+        let alloc = LocalAllocator
+            .allocate(AllocSpec {
+                extent: extent! { replica = 1 },
+                constraints: Default::default(),
+            })
+            .await?;
+
+        let mut proc_mesh = ProcMesh::allocate(alloc).await?;
+
+        // Extract events before wrapping in TrackedProcMesh
+        let events = proc_mesh.events().unwrap();
+        let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
+
+        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
+
+        // Create flags to track callback executions
+        let callback1_executed = Arc::new(AtomicBool::new(false));
+        let callback2_executed = Arc::new(AtomicBool::new(false));
+
+        let callback1_executed_clone = callback1_executed.clone();
+        let callback2_executed_clone = callback2_executed.clone();
+
+        // Register a first callback
+        tracked_proc_mesh
+            .register_onstop_callback(move || {
+                let flag = callback1_executed_clone.clone();
+                async move {
+                    flag.store(true, Ordering::SeqCst);
+                    // This callback completes successfully
+                }
+            })
+            .await?;
+
+        // Register another callback that should still execute even if the first one had issues
+        tracked_proc_mesh
+            .register_onstop_callback(move || {
+                let flag = callback2_executed_clone.clone();
+                async move {
+                    flag.store(true, Ordering::SeqCst);
+                }
+            })
+            .await?;
+
+        // Create a SharedCell<TrackedProcMesh> for stop_mesh
+        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
+
+        // Call stop_mesh (this should trigger both callbacks)
+        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
+
+        // Verify both callbacks were executed
+        assert!(
+            callback1_executed.load(Ordering::SeqCst),
+            "First callback should have been executed"
+        );
+        assert!(
+            callback2_executed.load(Ordering::SeqCst),
+            "Second callback should have been executed"
+        );
+
+        Ok(())
+    }
+}

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 5 deletions
@@ -10,7 +10,6 @@
 import logging
 
 import click
-from monarch._src.actor.future import Future
 
 from monarch.actor import Actor, endpoint, proc_mesh
 
@@ -41,10 +40,7 @@ async def _flush_logs() -> None:
     for _ in range(5):
         await am.print.call("has print streaming")
 
-    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
-    assert log_mesh is not None
-    Future(coro=log_mesh.flush().spawn().task()).get()
+    await pm.stop()
 
 
 @main.command("flush-logs")
