Skip to content

Commit 02df0cf

Browse files
James Sun authored and facebook-github-bot committed
check if we are in ipython env
Summary: In an IPython notebook, a cell can finish quickly while the process keeps running in the background. However, the background process will no longer flush its logs to that cell. This patch registers a flush function that runs when a cell exits. Differential Revision: D79982702
1 parent 02cd25f commit 02df0cf

File tree

5 files changed

+164
-5
lines changed

5 files changed

+164
-5
lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use hyperactor::HandleClient;
2727
use hyperactor::Handler;
2828
use hyperactor::Instance;
2929
use hyperactor::Named;
30+
use hyperactor::OncePortRef;
3031
use hyperactor::RefClient;
3132
use hyperactor::channel;
3233
use hyperactor::channel::ChannelAddr;
@@ -279,6 +280,11 @@ pub enum LogClientMessage {
279280
/// The time window in seconds to aggregate logs. If None, aggregation is disabled.
280281
aggregate_window_sec: Option<u64>,
281282
},
283+
284+
Flush {
285+
/// Synchronously flush all the logs
286+
reply: OncePortRef<()>,
287+
},
282288
}
283289

284290
/// Trait for sending logs
@@ -748,6 +754,12 @@ impl LogClientActor {
748754
OutputTarget::Stderr => eprintln!("{}", message),
749755
}
750756
}
757+
758+
/// Synchronous flush helper shared by the message handlers of this actor.
///
/// Runs `print_aggregators`, records the current system time as
/// `last_flush_time`, and clears `next_flush_deadline`. It takes no context
/// and is not `async`, so callers can invoke it directly without awaiting.
fn flush_internal(&mut self) {
759+
self.print_aggregators();
760+
// Remember when this flush happened.
self.last_flush_time = RealClock.system_time_now();
761+
// No flush is pending any more.
self.next_flush_deadline = None;
762+
}
751763
}
752764

753765
#[async_trait]
@@ -817,7 +829,7 @@ impl LogMessageHandler for LogClientActor {
817829
let new_deadline = self.last_flush_time + Duration::from_secs(window);
818830
let now = RealClock.system_time_now();
819831
if new_deadline <= now {
820-
self.flush(cx).await?;
832+
self.flush_internal();
821833
} else {
822834
let delay = new_deadline.duration_since(now)?;
823835
match self.next_flush_deadline {
@@ -842,9 +854,7 @@ impl LogMessageHandler for LogClientActor {
842854
}
843855

844856
/// `LogMessageHandler` flush: performs an immediate flush by delegating to
/// `flush_internal`, then returns `Ok(())`. The context argument is unused.
async fn flush(&mut self, _cx: &Context<Self>) -> Result<(), anyhow::Error> {
// Diff note: the three `-` lines below are the previous inline body; the `+`
// line replaces them with a call to the shared helper.
845-
self.print_aggregators();
846-
self.last_flush_time = RealClock.system_time_now();
847-
self.next_flush_deadline = None;
857+
self.flush_internal();
848858

849859
Ok(())
850860
}
@@ -865,6 +875,15 @@ impl LogClientMessageHandler for LogClientActor {
865875
self.aggregate_window_sec = aggregate_window_sec;
866876
Ok(())
867877
}
878+
879+
/// `LogClientMessageHandler` flush: flushes synchronously and then
/// acknowledges on `reply` so the sender can wait for completion.
/// Presumably this is the handler for `LogClientMessage::Flush { reply }`.
///
/// # Errors
///
/// Returns an error if sending the acknowledgement on `reply` fails.
async fn flush(
880+
&mut self,
881+
cx: &Context<Self>,
882+
reply: OncePortRef<()>,
883+
) -> Result<(), anyhow::Error> {
884+
// Flush first, so the reply is only sent after the flush has run.
self.flush_internal();
885+
reply.send(cx, ()).map_err(anyhow::Error::from)
886+
}
868887
}
869888

870889
#[cfg(test)]

monarch_extension/src/logging.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#![allow(unsafe_op_in_unsafe_fn)]
1010

1111
use hyperactor::ActorHandle;
12+
use hyperactor::PortRef;
1213
use hyperactor_mesh::RootActorMesh;
1314
use hyperactor_mesh::actor_mesh::ActorMesh;
1415
use hyperactor_mesh::logging::LogClientActor;
1516
use hyperactor_mesh::logging::LogClientMessage;
1617
use hyperactor_mesh::logging::LogForwardActor;
1718
use hyperactor_mesh::logging::LogForwardMessage;
19+
use hyperactor_mesh::logging::LogMessage;
1820
use hyperactor_mesh::selection::Selection;
1921
use hyperactor_mesh::shared_cell::SharedCell;
2022
use monarch_hyperactor::logging::LoggerRuntimeActor;
@@ -90,6 +92,18 @@ impl LoggingMeshClient {
9092

9193
Ok(())
9294
}
95+
96+
/// Asks the log client actor to flush and returns a Python task that
/// resolves once the actor has acknowledged the request.
///
/// A one-shot reply port is opened on the client of `proc_mesh`, its bound
/// reference is sent to `client_actor` inside `LogClientMessage::Flush`, and
/// the returned task awaits the receiving end of that port.
///
/// # Errors
///
/// Propagates the error from `proc_mesh.try_inner()`. A failure to send the
/// message, or (inside the task) to receive the reply, is surfaced as a
/// Python `RuntimeError` carrying the error's string form.
fn flush(&self, proc_mesh: &PyProcMesh) -> PyResult<PyPythonTask> {
97+
let (tx, rx) = proc_mesh.try_inner()?.client().open_once_port::<()>();
98+
self.client_actor
99+
.send(LogClientMessage::Flush { reply: tx.bind() })
100+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
101+
// The message is already sent at this point; the task only waits for the ack.
PyPythonTask::new(async move {
102+
rx.recv()
103+
.await
104+
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
105+
})
106+
}
93107
}
94108

95109
impl Drop for LoggingMeshClient {

python/monarch/_rust_bindings/monarch_extension/logging.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ class LoggingMeshClient:
2121
def set_mode(
2222
self, stream_to_client: bool, aggregate_window_sec: int | None, level: int
2323
) -> None: ...
24+
# Ask the log client to flush its logs. The returned task completes once the
# flush request has been acknowledged.
def flush(self, proc_mesh: ProcMesh) -> PythonTask[None]: ...

python/monarch/_src/actor/proc_mesh.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@
6767
from monarch._src.actor.future import DeprecatedNotAFuture, Future
6868
from monarch._src.actor.shape import MeshTrait
6969

70+
try:
71+
# Check if we are in ipython environment
72+
# pyre-ignore[21]
73+
from IPython import get_ipython
74+
except ImportError:
75+
get_ipython = None
76+
7077
HAS_TENSOR_ENGINE = False
7178
try:
7279
# Torch is needed for tensor engine
@@ -163,7 +170,7 @@ async def _init_manager_actors_coro(
163170
proc_mesh_: "Shared[HyProcMesh]",
164171
setup: Callable[[], None] | None = None,
165172
) -> "HyProcMesh":
166-
proc_mesh = await proc_mesh_
173+
proc_mesh: HyProcMesh = await proc_mesh_
167174
# WARNING: it is unsafe to await self._proc_mesh here
168175
# because self._proc_mesh is the result of this function itself!
169176

@@ -173,6 +180,21 @@ async def _init_manager_actors_coro(
173180
aggregate_window_sec=3,
174181
level=logging.INFO,
175182
)
183+
if get_ipython:
184+
# For ipython environment, a cell can end fast with threads running in background.
185+
# Flush all the ongoing logs proactively to avoid missing logs.
186+
assert self._logging_mesh_client is not None
187+
logging_client: LoggingMeshClient = self._logging_mesh_client
188+
ipython = get_ipython()
189+
190+
# pyre-ignore[21]
191+
from IPython.core.interactiveshell import ExecutionResult
192+
193+
# pyre-ignore[11]
194+
def flush_logs(_: ExecutionResult) -> None:
195+
return Future(coro=logging_client.flush(proc_mesh).spawn().task()).get()
196+
197+
ipython.events.register("post_run_cell", flush_logs)
176198

177199
_rdma_manager = (
178200
# type: ignore[16]

python/tests/test_python_actors.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,109 @@ async def test_logging_option_defaults() -> None:
715715
pass
716716

717717

718+
@pytest.mark.timeout(60)
719+
async def test_flush_logs_ipython() -> None:
720+
"""Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
721+
# Save original file descriptors
722+
original_stdout_fd = os.dup(1) # stdout
723+
724+
try:
725+
# Create temporary files to capture output
# delete=False: the file is reopened by path after the `with` block and
# removed explicitly with os.unlink below.
726+
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
727+
stdout_path = stdout_file.name
728+
729+
# Redirect file descriptors to our temp files
# Redirecting fd 1 (not just sys.stdout) also captures output written
# directly to the descriptor rather than through Python's sys.stdout.
730+
os.dup2(stdout_file.fileno(), 1)
731+
732+
# Also redirect Python's sys.stdout
733+
original_sys_stdout = sys.stdout
734+
sys.stdout = stdout_file
735+
736+
try:
737+
# Mock IPython environment
738+
class MockExecutionResult:
739+
pass
740+
741+
# Minimal stand-in for the shell's `events` object: `register` stores
# callbacks per event name, and `trigger` lets the test fire them.
class MockEvents:
742+
def __init__(self):
743+
self.callbacks = {}
744+
745+
def register(self, event_name, callback):
746+
if event_name not in self.callbacks:
747+
self.callbacks[event_name] = []
748+
self.callbacks[event_name].append(callback)
749+
750+
def trigger(self, event_name, *args, **kwargs):
751+
if event_name in self.callbacks:
752+
for callback in self.callbacks[event_name]:
753+
callback(*args, **kwargs)
754+
755+
class MockIPython:
756+
def __init__(self):
757+
self.events = MockEvents()
758+
759+
mock_ipython = MockIPython()
760+
761+
# Patch get_ipython to return our mock using unittest.mock
762+
import unittest.mock
763+
764+
import monarch._src.actor.proc_mesh as proc_mesh_module
765+
766+
# The patch is only active inside this `with`, so the proc mesh is created
# here; presumably its initialization registers the flush callback on the mock.
with unittest.mock.patch.object(
767+
proc_mesh_module, "get_ipython", lambda: mock_ipython
768+
):
769+
pm = await proc_mesh(gpus=2)
770+
am = await pm.spawn("printer", Printer)
771+
772+
# Set aggregation window to ensure logs are buffered
# 600s is far longer than the 60s test timeout, so aggregated output seen
# before the assertion should come from the explicit flush, not the timer.
773+
await pm.logging_option(
774+
stream_to_client=True, aggregate_window_sec=600
775+
)
776+
await asyncio.sleep(1)
777+
778+
# Generate some logs that will be aggregated
779+
for _ in range(5):
780+
await am.print.call("ipython test log")
781+
782+
# Trigger the post_run_cell event which should flush logs
783+
mock_ipython.events.trigger("post_run_cell", MockExecutionResult())
784+
785+
# Wait a bit to ensure flush completes
786+
await asyncio.sleep(1)
787+
788+
# Flush all outputs
789+
stdout_file.flush()
790+
os.fsync(stdout_file.fileno())
791+
792+
finally:
793+
# Restore Python's sys.stdout
794+
sys.stdout = original_sys_stdout
795+
796+
# Restore original file descriptors
797+
os.dup2(original_stdout_fd, 1)
798+
799+
# Read the captured output
800+
with open(stdout_path, "r") as f:
801+
stdout_content = f.read()
802+
803+
# Clean up temp files
804+
os.unlink(stdout_path)
805+
806+
# Verify that logs were flushed when the post_run_cell event was triggered
807+
# We should see the aggregated logs in the output
# The count of 10 is presumably 2 procs (gpus=2) x 5 print calls,
# collapsed into a single aggregated line.
808+
assert re.search(
809+
r"\[10 similar log lines\].*ipython test log", stdout_content
810+
), stdout_content
811+
812+
finally:
813+
# Ensure file descriptors are restored even if something goes wrong
# This repeats the dup2 from the normal path and also closes the saved
# duplicate descriptor.
814+
try:
815+
os.dup2(original_stdout_fd, 1)
816+
os.close(original_stdout_fd)
817+
except OSError:
818+
pass
819+
820+
718821
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
719822
@pytest.mark.oss_skip
720823
async def test_flush_logs_fast_exit() -> None:

0 commit comments

Comments
 (0)