Skip to content

Commit 9391ffd

Browse files
James Sun authored and facebook-github-bot committed
flush log upon ipython notebook cell exit (#816)
Summary: In ipython notebook, a cell can end fast. Yet the process can still run in the background. However, the background process will not flush logs to the existing cell anymore. The patch registers the flush function upon a cell exiting. Differential Revision: D79982702
1 parent e762303 commit 9391ffd

File tree

3 files changed

+235
-16
lines changed

3 files changed

+235
-16
lines changed

python/monarch/_src/actor/logging.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import gc
10+
import logging
11+
12+
from typing import Callable
13+
14+
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
15+
16+
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
17+
from monarch._src.actor.future import Future
18+
19+
# Detect at import time whether we are running inside an IPython/Jupyter shell.
# Downstream code uses this flag to decide whether to register a log-flush hook
# on cell exit (see LoggingManager.init).
IN_IPYTHON = False
try:
    # Check if we are in ipython environment
    # pyre-ignore[21]
    from IPython import get_ipython

    # pyre-ignore[21]
    from IPython.core.interactiveshell import ExecutionResult

    # get_ipython() returns None when IPython is importable but we are not
    # actually running under an interactive shell.
    IN_IPYTHON = get_ipython() is not None
except ImportError:
    # IPython not installed at all — plain Python process.
    pass
31+
32+
33+
class LoggingManager:
    """Owns the per-ProcMesh logging client and, in IPython, a cell-exit hook.

    In an IPython notebook a cell can finish while actors keep running in the
    background; without intervention their aggregated logs would never be
    flushed into the finished cell. ``init`` registers a ``post_run_cell``
    callback that proactively flushes logs, and ``stop`` unregisters it.
    """

    def __init__(self) -> None:
        # Lazily created in init(); None means init() has not run yet.
        self._logging_mesh_client: LoggingMeshClient | None = None
        # Reference to the registered IPython callback so stop() can
        # unregister exactly the same object; None when not registered.
        self._ipython_flush_logs_handler: Callable[..., None] | None = None

    async def init(self, proc_mesh: HyProcMesh) -> None:
        """Spawn the logging client for *proc_mesh* and hook IPython cell exit.

        Idempotent: subsequent calls after the first are no-ops.
        """
        if self._logging_mesh_client is not None:
            return

        self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
        # Defaults: stream to the client, aggregate over a 3 second window,
        # at INFO level.
        self._logging_mesh_client.set_mode(
            stream_to_client=True,
            aggregate_window_sec=3,
            level=logging.INFO,
        )

        if IN_IPYTHON:
            # For ipython environment, a cell can end fast with threads running
            # in background. Flush all the ongoing logs proactively to avoid
            # missing logs.
            assert self._logging_mesh_client is not None
            # Bind to a local so the closure below does not capture self
            # (avoids keeping this manager alive via the IPython callback).
            logging_client: LoggingMeshClient = self._logging_mesh_client
            ipython = get_ipython()

            # pyre-ignore[11]
            def flush_logs(_: ExecutionResult) -> None:
                # Runs after every cell. Bounded to 3s so a dead/unresponsive
                # proc mesh cannot hang the notebook.
                try:
                    Future(coro=logging_client.flush().spawn().task()).get(3)
                except TimeoutError:
                    # We need to prevent failed proc meshes not coming back
                    pass

                # Force to recycle previous undropped proc_mesh.
                # Otherwise, we may end up with unregistered dead callbacks.
                gc.collect()

            # Store the handler reference so we can unregister it later
            self._ipython_flush_logs_handler = flush_logs
            ipython.events.register("post_run_cell", flush_logs)

    async def logging_option(
        self,
        stream_to_client: bool = True,
        aggregate_window_sec: int | None = 3,
        level: int = logging.INFO,
    ) -> None:
        """Reconfigure streaming/aggregation/level on the logging client.

        Raises:
            ValueError: if *level* is outside the wire-representable 0-255 range.
        """
        if level < 0 or level > 255:
            raise ValueError("Invalid logging level: {}".format(level))

        # Caller must have awaited init() first.
        assert self._logging_mesh_client is not None
        self._logging_mesh_client.set_mode(
            stream_to_client=stream_to_client,
            aggregate_window_sec=aggregate_window_sec,
            level=level,
        )

    def stop(self) -> None:
        """Unregister the IPython cell-exit flush hook, if one was installed.

        Safe to call multiple times; the handler reference is cleared so a
        second call is a no-op.
        """
        if self._ipython_flush_logs_handler is not None:
            # A handler is only ever registered from the IN_IPYTHON branch
            # of init(), so these invariants must hold here.
            assert IN_IPYTHON
            ipython = get_ipython()
            assert ipython is not None
            ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
            self._ipython_flush_logs_handler = None

python/monarch/_src/actor/proc_mesh.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
TypeVar,
3232
)
3333

34-
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
35-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
3634
from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
3735
Alloc,
3836
AllocConstraints,
@@ -71,10 +69,12 @@
7169

7270
from monarch._src.actor.endpoint import endpoint
7371
from monarch._src.actor.future import DeprecatedNotAFuture, Future
72+
from monarch._src.actor.logging import LoggingManager
7473
from monarch._src.actor.shape import MeshTrait
7574
from monarch.tools.config import Workspace
7675
from monarch.tools.utils import conda as conda_utils
7776

77+
7878
HAS_TENSOR_ENGINE = False
7979
try:
8080
# Torch is needed for tensor engine
@@ -154,7 +154,7 @@ def __init__(
154154
self._rdma_manager: Optional["_RdmaManager"] = None
155155
self._debug_manager: Optional[DebugManager] = None
156156
self._code_sync_client: Optional[CodeSyncMeshClient] = None
157-
self._logging_mesh_client: Optional[LoggingMeshClient] = None
157+
self._logging_manager: LoggingManager = LoggingManager()
158158
self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
159159
self._fork_processes = _fork_processes
160160
self._stopped = False
@@ -194,14 +194,7 @@ async def _init_manager_actors_coro(
194194

195195
if _fork_processes:
196196
# logging mesh is only makes sense with forked (remote or local) processes
197-
self._logging_mesh_client = await LoggingMeshClient.spawn(
198-
proc_mesh=proc_mesh
199-
)
200-
self._logging_mesh_client.set_mode(
201-
stream_to_client=True,
202-
aggregate_window_sec=3,
203-
level=logging.INFO,
204-
)
197+
await self._logging_manager.init(proc_mesh)
205198

206199
_rdma_manager = (
207200
# type: ignore[16]
@@ -471,12 +464,9 @@ async def logging_option(
471464
"Logging option is only available for allocators that fork processes. Allocators like LocalAllocator are not supported."
472465
)
473466

474-
if level < 0 or level > 255:
475-
raise ValueError("Invalid logging level: {}".format(level))
476467
await self.initialized
477468

478-
assert self._logging_mesh_client is not None
479-
self._logging_mesh_client.set_mode(
469+
await self._logging_manager.logging_option(
480470
stream_to_client=stream_to_client,
481471
aggregate_window_sec=aggregate_window_sec,
482472
level=level,
@@ -489,6 +479,8 @@ async def __aenter__(self) -> "ProcMesh":
489479

490480
def stop(self) -> Future[None]:
491481
async def _stop_nonblocking() -> None:
482+
self._logging_manager.stop()
483+
492484
await (await self._proc_mesh).stop_nonblocking()
493485
self._stopped = True
494486

@@ -505,6 +497,8 @@ async def __aexit__(
505497
# Finalizer to check if the proc mesh was closed properly.
506498
def __del__(self) -> None:
507499
if not self._stopped:
500+
self._logging_manager.stop()
501+
508502
warnings.warn(
509503
f"unstopped ProcMesh {self!r}",
510504
ResourceWarning,

python/tests/test_python_actors.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,137 @@ async def test_logging_option_defaults() -> None:
718718
pass
719719

720720

721+
# oss_skip: pytest keeps complaining about mocking get_ipython module
@pytest.mark.oss_skip
@pytest.mark.timeout(60)
async def test_flush_logs_ipython() -> None:
    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
    # Save original file descriptors
    original_stdout_fd = os.dup(1)  # stdout

    try:
        # Create temporary files to capture output
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
            stdout_path = stdout_file.name

            # Redirect file descriptors to our temp files
            # (fd-level dup2 so output written by non-Python code is captured too)
            os.dup2(stdout_file.fileno(), 1)

            # Also redirect Python's sys.stdout
            original_sys_stdout = sys.stdout
            sys.stdout = stdout_file

            try:
                # Mock IPython environment
                class MockExecutionResult:
                    # Stand-in for IPython.core.interactiveshell.ExecutionResult;
                    # the flush handler ignores its argument, so no attributes needed.
                    pass

                class MockEvents:
                    # Minimal reimplementation of IPython's event bus: register /
                    # unregister callbacks per event name, plus a trigger() helper
                    # the test uses to simulate a cell finishing.
                    def __init__(self):
                        self.callbacks = {}

                    def register(self, event_name, callback):
                        if event_name not in self.callbacks:
                            self.callbacks[event_name] = []
                        self.callbacks[event_name].append(callback)

                    def unregister(self, event_name, callback):
                        # Mirrors IPython's behavior of rejecting unknown events;
                        # also asserts the exact handler object was registered.
                        if event_name not in self.callbacks:
                            raise ValueError(f"Event {event_name} not registered")
                        assert callback in self.callbacks[event_name]
                        self.callbacks[event_name].remove(callback)

                    def trigger(self, event_name, *args, **kwargs):
                        if event_name in self.callbacks:
                            for callback in self.callbacks[event_name]:
                                callback(*args, **kwargs)

                class MockIPython:
                    def __init__(self):
                        self.events = MockEvents()

                mock_ipython = MockIPython()

                # Patch get_ipython to return our mock using unittest.mock
                import unittest.mock

                with unittest.mock.patch(
                    "monarch._src.actor.logging.get_ipython",
                    lambda: mock_ipython,
                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
                    # Make sure we can register and unregister callbacks
                    # (three iterations -> three flush cycles expected below)
                    for _ in range(3):
                        pm1 = await proc_mesh(gpus=2)
                        pm2 = await proc_mesh(gpus=2)
                        am1 = await pm1.spawn("printer", Printer)
                        am2 = await pm2.spawn("printer", Printer)

                        # Set aggregation window to ensure logs are buffered
                        # (600s is effectively "never flush on its own")
                        await pm1.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        await pm2.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        await asyncio.sleep(1)

                        # Generate some logs that will be aggregated
                        # (5 calls x 2 procs per mesh = 10 similar lines each)
                        for _ in range(5):
                            await am1.print.call("ipython1 test log")
                            await am2.print.call("ipython2 test log")

                        # Trigger the post_run_cell event which should flush logs
                        mock_ipython.events.trigger(
                            "post_run_cell", MockExecutionResult()
                        )

                # Flush all outputs
                stdout_file.flush()
                os.fsync(stdout_file.fileno())

            finally:
                # Restore Python's sys.stdout
                sys.stdout = original_sys_stdout

                # Restore original file descriptors
                os.dup2(original_stdout_fd, 1)

            # Read the captured output
            with open(stdout_path, "r") as f:
                stdout_content = f.read()

            # Clean up temp files
            os.unlink(stdout_path)

            # Verify that logs were flushed when the post_run_cell event was triggered
            # We should see the aggregated logs in the output
            assert (
                len(
                    re.findall(
                        r"\[10 similar log lines\].*ipython1 test log", stdout_content
                    )
                )
                == 3
            ), stdout_content

            assert (
                len(
                    re.findall(
                        r"\[10 similar log lines\].*ipython2 test log", stdout_content
                    )
                )
                == 3
            ), stdout_content

    finally:
        # Ensure file descriptors are restored even if something goes wrong
        # (dup2 may run twice — that is harmless; close only happens here)
        try:
            os.dup2(original_stdout_fd, 1)
            os.close(original_stdout_fd)
        except OSError:
            pass
850+
851+
721852
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
722853
@pytest.mark.oss_skip
723854
async def test_flush_logs_fast_exit() -> None:
@@ -849,7 +980,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
849980
for _ in range(10):
850981
await am.print.call("aggregated log line")
851982

852-
log_mesh = pm._logging_mesh_client
983+
log_mesh = pm._logging_manager._logging_mesh_client
853984
assert log_mesh is not None
854985
futures = []
855986
for _ in range(5):

0 commit comments

Comments
 (0)