Skip to content

Commit ab890a3

Browse files
highker and facebook-github-bot
authored and committed
sync flush logs upon mesh stop
Summary: force sync flush upon mesh stop. Differential Revision: D80310284
1 parent 9177ca1 commit ab890a3

File tree

4 files changed

+271
-25
lines changed

4 files changed

+271
-25
lines changed

monarch_extension/src/logging.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
#![allow(unsafe_op_in_unsafe_fn)]
1010

11+
use std::time::Duration;
12+
1113
use hyperactor::ActorHandle;
14+
use hyperactor::clock::Clock;
15+
use hyperactor::clock::RealClock;
1216
use hyperactor_mesh::RootActorMesh;
1317
use hyperactor_mesh::actor_mesh::ActorMesh;
1418
use hyperactor_mesh::logging::LogClientActor;
@@ -27,6 +31,8 @@ use pyo3::Bound;
2731
use pyo3::prelude::*;
2832
use pyo3::types::PyModule;
2933

34+
static FLUSH_TIMEOUT: Duration = Duration::from_secs(30);
35+
3036
#[pyclass(
3137
frozen,
3238
name = "LoggingMeshClient",
@@ -89,6 +95,38 @@ impl LoggingMeshClient {
8995
let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
9096
let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?;
9197
let logger_mesh = proc_mesh.spawn("logger", &()).await?;
98+
99+
// Register flush_internal as an on-stop callback
100+
let client_actor_for_callback = client_actor.clone();
101+
let flush_mesh_for_callback = flush_mesh.clone();
102+
proc_mesh
103+
.register_onstop_callback(|| async move {
104+
match RealClock
105+
.timeout(
106+
FLUSH_TIMEOUT,
107+
Self::flush_internal(
108+
client_actor_for_callback,
109+
flush_mesh_for_callback,
110+
),
111+
)
112+
.await
113+
{
114+
Ok(Ok(())) => {
115+
tracing::debug!("flush completed successfully during shutdown");
116+
}
117+
Ok(Err(e)) => {
118+
tracing::error!("error during flush: {}", e);
119+
}
120+
Err(_) => {
121+
tracing::error!(
122+
"flush timed out after {} seconds during shutdown",
123+
FLUSH_TIMEOUT.as_secs()
124+
);
125+
}
126+
}
127+
})
128+
.await?;
129+
92130
Ok(Self {
93131
forwarder_mesh,
94132
flush_mesh,
@@ -107,7 +145,7 @@ impl LoggingMeshClient {
107145
) -> PyResult<()> {
108146
if aggregate_window_sec.is_some() && !stream_to_client {
109147
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
110-
"Cannot set aggregate window without streaming to client".to_string(),
148+
"cannot set aggregate window without streaming to client".to_string(),
111149
));
112150
}
113151

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 227 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ use pyo3::types::PyType;
3838
use tokio::sync::Mutex;
3939
use tokio::sync::mpsc;
4040

41+
type OnStopCallback = Box<dyn FnOnce() -> Box<dyn std::future::Future<Output = ()> + Send> + Send>;
42+
4143
use crate::actor_mesh::PythonActorMesh;
4244
use crate::actor_mesh::PythonActorMeshImpl;
4345
use crate::alloc::PyAlloc;
@@ -55,6 +57,7 @@ pub struct TrackedProcMesh {
5557
inner: SharedCellRef<ProcMesh>,
5658
cell: SharedCell<ProcMesh>,
5759
children: SharedCellPool,
60+
onstop_callbacks: Arc<Mutex<Vec<OnStopCallback>>>,
5861
}
5962

6063
impl Debug for TrackedProcMesh {
@@ -77,6 +80,7 @@ impl From<ProcMesh> for TrackedProcMesh {
7780
inner,
7881
cell,
7982
children: SharedCellPool::new(),
83+
onstop_callbacks: Arc::new(Mutex::new(Vec::new())),
8084
}
8185
}
8286
}
@@ -107,8 +111,25 @@ impl TrackedProcMesh {
107111
self.inner.client_proc()
108112
}
109113

110-
pub fn into_inner(self) -> (SharedCell<ProcMesh>, SharedCellPool) {
111-
(self.cell, self.children)
114+
pub fn into_inner(
115+
self,
116+
) -> (
117+
SharedCell<ProcMesh>,
118+
SharedCellPool,
119+
Arc<Mutex<Vec<OnStopCallback>>>,
120+
) {
121+
(self.cell, self.children, self.onstop_callbacks)
122+
}
123+
124+
/// Register a callback to be called when this TrackedProcMesh is stopped
125+
pub async fn register_onstop_callback<F, Fut>(&self, callback: F) -> Result<(), anyhow::Error>
126+
where
127+
F: FnOnce() -> Fut + Send + 'static,
128+
Fut: std::future::Future<Output = ()> + Send + 'static,
129+
{
130+
let mut callbacks = self.onstop_callbacks.lock().await;
131+
callbacks.push(Box::new(|| Box::new(callback())));
132+
Ok(())
112133
}
113134
}
114135

@@ -230,7 +251,17 @@ impl PyProcMesh {
230251
let tracked_proc_mesh = inner.take().await.map_err(|e| {
231252
PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e))
232253
})?;
233-
let (proc_mesh, children) = tracked_proc_mesh.into_inner();
254+
let (proc_mesh, children, drop_callbacks) = tracked_proc_mesh.into_inner();
255+
256+
// Call all registered drop callbacks before stopping
257+
let mut callbacks = drop_callbacks.lock().await;
258+
let callbacks_to_call = callbacks.drain(..).collect::<Vec<_>>();
259+
drop(callbacks); // Release the lock
260+
261+
for callback in callbacks_to_call {
262+
let future = callback();
263+
std::pin::Pin::from(future).await;
264+
}
234265

235266
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
236267
children.discard_all().await?;
@@ -486,3 +517,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
486517
hyperactor_mod.add_class::<PyProcEvent>()?;
487518
Ok(())
488519
}
520+
521+
#[cfg(test)]
522+
mod tests {
523+
use std::sync::Arc;
524+
use std::sync::atomic::AtomicBool;
525+
use std::sync::atomic::AtomicU32;
526+
use std::sync::atomic::Ordering;
527+
528+
use anyhow::Result;
529+
use hyperactor_mesh::alloc::AllocSpec;
530+
use hyperactor_mesh::alloc::Allocator;
531+
use hyperactor_mesh::alloc::local::LocalAllocator;
532+
use hyperactor_mesh::proc_mesh::ProcMesh;
533+
use ndslice::extent;
534+
use tokio::sync::Mutex;
535+
536+
use super::*;
537+
538+
#[tokio::test]
539+
async fn test_register_onstop_callback_single() -> Result<()> {
540+
// Create a TrackedProcMesh
541+
let alloc = LocalAllocator
542+
.allocate(AllocSpec {
543+
extent: extent! { replica = 1 },
544+
constraints: Default::default(),
545+
})
546+
.await?;
547+
548+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
549+
550+
// Extract events before wrapping in TrackedProcMesh
551+
let events = proc_mesh.events().unwrap();
552+
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
553+
554+
let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
555+
556+
// Create a flag to track if callback was executed
557+
let callback_executed = Arc::new(AtomicBool::new(false));
558+
let callback_executed_clone = callback_executed.clone();
559+
560+
// Register a callback
561+
tracked_proc_mesh
562+
.register_onstop_callback(move || {
563+
let flag = callback_executed_clone.clone();
564+
async move {
565+
flag.store(true, Ordering::SeqCst);
566+
}
567+
})
568+
.await?;
569+
570+
// Create a SharedCell<TrackedProcMesh> for stop_mesh
571+
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
572+
573+
// Call stop_mesh (this should trigger the callback)
574+
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
575+
576+
// Verify the callback was executed
577+
assert!(
578+
callback_executed.load(Ordering::SeqCst),
579+
"Callback should have been executed"
580+
);
581+
582+
Ok(())
583+
}
584+
585+
#[tokio::test]
586+
async fn test_register_onstop_callback_multiple() -> Result<()> {
587+
// Create a TrackedProcMesh
588+
let alloc = LocalAllocator
589+
.allocate(AllocSpec {
590+
extent: extent! { replica = 1 },
591+
constraints: Default::default(),
592+
})
593+
.await?;
594+
595+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
596+
597+
// Extract events before wrapping in TrackedProcMesh
598+
let events = proc_mesh.events().unwrap();
599+
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
600+
601+
let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
602+
603+
// Create counters to track callback executions
604+
let callback_count = Arc::new(AtomicU32::new(0));
605+
let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));
606+
607+
// Register multiple callbacks
608+
for i in 1..=3 {
609+
let count = callback_count.clone();
610+
let order = execution_order.clone();
611+
tracked_proc_mesh
612+
.register_onstop_callback(move || {
613+
let count_clone = count.clone();
614+
let order_clone = order.clone();
615+
async move {
616+
count_clone.fetch_add(1, Ordering::SeqCst);
617+
let mut order_vec = order_clone.lock().await;
618+
order_vec.push(i);
619+
}
620+
})
621+
.await?;
622+
}
623+
624+
// Create a SharedCell<TrackedProcMesh> for stop_mesh
625+
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
626+
627+
// Call stop_mesh (this should trigger all callbacks)
628+
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
629+
630+
// Verify all callbacks were executed
631+
assert_eq!(
632+
callback_count.load(Ordering::SeqCst),
633+
3,
634+
"All 3 callbacks should have been executed"
635+
);
636+
637+
// Verify execution order (callbacks should be executed in registration order)
638+
let order_vec = execution_order.lock().await;
639+
assert_eq!(
640+
*order_vec,
641+
vec![1, 2, 3],
642+
"Callbacks should be executed in registration order"
643+
);
644+
645+
Ok(())
646+
}
647+
648+
#[tokio::test]
649+
async fn test_register_onstop_callback_error_handling() -> Result<()> {
650+
// Create a TrackedProcMesh
651+
let alloc = LocalAllocator
652+
.allocate(AllocSpec {
653+
extent: extent! { replica = 1 },
654+
constraints: Default::default(),
655+
})
656+
.await?;
657+
658+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
659+
660+
// Extract events before wrapping in TrackedProcMesh
661+
let events = proc_mesh.events().unwrap();
662+
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
663+
664+
let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
665+
666+
// Create flags to track callback executions
667+
let callback1_executed = Arc::new(AtomicBool::new(false));
668+
let callback2_executed = Arc::new(AtomicBool::new(false));
669+
670+
let callback1_executed_clone = callback1_executed.clone();
671+
let callback2_executed_clone = callback2_executed.clone();
672+
673+
// Register a callback that completes successfully
674+
tracked_proc_mesh
675+
.register_onstop_callback(move || {
676+
let flag = callback1_executed_clone.clone();
677+
async move {
678+
flag.store(true, Ordering::SeqCst);
679+
// This callback completes successfully
680+
}
681+
})
682+
.await?;
683+
684+
// Register another callback that should still execute even if the first one had issues
685+
tracked_proc_mesh
686+
.register_onstop_callback(move || {
687+
let flag = callback2_executed_clone.clone();
688+
async move {
689+
flag.store(true, Ordering::SeqCst);
690+
}
691+
})
692+
.await?;
693+
694+
// Create a SharedCell<TrackedProcMesh> for stop_mesh
695+
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
696+
697+
// Call stop_mesh (this should trigger both callbacks)
698+
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
699+
700+
// Verify both callbacks were executed
701+
assert!(
702+
callback1_executed.load(Ordering::SeqCst),
703+
"First callback should have been executed"
704+
);
705+
assert!(
706+
callback2_executed.load(Ordering::SeqCst),
707+
"Second callback should have been executed"
708+
);
709+
710+
Ok(())
711+
}
712+
}

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import logging
1111

1212
import click
13-
from monarch._src.actor.future import Future
1413

1514
from monarch.actor import Actor, endpoint, proc_mesh
1615

@@ -41,10 +40,7 @@ async def _flush_logs() -> None:
4140
for _ in range(5):
4241
await am.print.call("has print streaming")
4342

44-
# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
45-
log_mesh = pm._logging_mesh_client
46-
assert log_mesh is not None
47-
Future(coro=log_mesh.flush().spawn().task()).get()
43+
await pm.stop()
4844

4945

5046
@main.command("flush-logs")

0 commit comments

Comments (0)