Skip to content

Commit e813a77

Browse files
James Sun and facebook-github-bot
authored and committed
flush log upon ipython notebook cell exit (#816)
Summary: Pull Request resolved: #816 In ipython notebook, a cell can end fast. Yet the process can still run in the background. However, the background process will not flush logs to the existing cell anymore. The patch registers the flush function upon a cell exiting. Differential Revision: D79982702
1 parent ab890a3 commit e813a77

File tree

3 files changed

+244
-16
lines changed

3 files changed

+244
-16
lines changed

python/monarch/_src/actor/logging.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import gc
10+
import logging
11+
12+
from typing import Callable
13+
14+
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
15+
16+
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
17+
from monarch._src.actor.future import Future
18+
19+
# Detect whether we are running inside an IPython shell (e.g. a Jupyter
# notebook). IPython is an optional dependency, so default to False and
# leave it False when the import is unavailable.
IN_IPYTHON = False
try:
    # Check if we are in ipython environment
    # pyre-ignore[21]
    from IPython import get_ipython

    # pyre-ignore[21]
    from IPython.core.interactiveshell import ExecutionResult

    # get_ipython() returns None when IPython is installed but we are not
    # running under an interactive shell.
    IN_IPYTHON = get_ipython() is not None
except ImportError:
    pass
31+
32+
33+
class LoggingManager:
    """Owns the log-streaming client for a single proc mesh.

    Responsibilities:
      * lazily spawn a ``LoggingMeshClient`` on first ``init()`` and set
        default streaming/aggregation options;
      * in an IPython notebook, register a ``post_run_cell`` hook that
        flushes buffered logs when a cell finishes executing — a cell can
        exit quickly while actors keep running (and logging) in the
        background, and those logs would otherwise never be written back to
        the finished cell's output;
      * unregister that hook again in ``stop()``.
    """

    def __init__(self) -> None:
        # Created lazily by init(); stays None until then.
        self._logging_mesh_client: LoggingMeshClient | None = None
        # The post_run_cell callback registered with IPython, kept so that
        # stop() can unregister exactly the same callable object.
        self._ipython_flush_logs_handler: Callable[..., None] | None = None

    async def init(self, proc_mesh: HyProcMesh) -> None:
        """Spawn and configure the logging client for *proc_mesh*.

        Idempotent: subsequent calls after a successful init are no-ops.
        """
        if self._logging_mesh_client is not None:
            return

        self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
        # Defaults: stream logs back to the client, aggregating similar
        # lines over a 3-second window, at INFO level.
        self._logging_mesh_client.set_mode(
            stream_to_client=True,
            aggregate_window_sec=3,
            level=logging.INFO,
        )

        if IN_IPYTHON:
            # For ipython environment, a cell can end fast with threads running in background.
            # Flush all the ongoing logs proactively to avoid missing logs.
            assert self._logging_mesh_client is not None
            # Bind to a local so the closure below does not capture self
            # (which would keep this manager alive via the IPython registry).
            logging_client: LoggingMeshClient = self._logging_mesh_client
            ipython = get_ipython()

            # pyre-ignore[11]
            def flush_logs(_: ExecutionResult) -> None:
                try:
                    # Bounded wait (3s) so an unresponsive proc mesh cannot
                    # hang the notebook at every cell boundary.
                    Future(coro=logging_client.flush().spawn().task()).get(3)
                except TimeoutError:
                    # We need to prevent failed proc meshes not coming back
                    pass

                # Force to recycle previous undropped proc_mesh.
                # Otherwise, we may end up with unregistered dead callbacks.
                gc.collect()

            # Store the handler reference so we can unregister it later
            self._ipython_flush_logs_handler = flush_logs
            ipython.events.register("post_run_cell", flush_logs)

    async def logging_option(
        self,
        stream_to_client: bool = True,
        aggregate_window_sec: int | None = 3,
        level: int = logging.INFO,
    ) -> None:
        """Reconfigure streaming, aggregation window, and log level.

        Must be called after ``init()`` has completed (asserted below).

        Raises:
            ValueError: if *level* is outside the byte range [0, 255].
        """
        if level < 0 or level > 255:
            raise ValueError("Invalid logging level: {}".format(level))

        assert self._logging_mesh_client is not None
        self._logging_mesh_client.set_mode(
            stream_to_client=stream_to_client,
            aggregate_window_sec=aggregate_window_sec,
            level=level,
        )

    def stop(self) -> None:
        """Unregister the IPython flush hook, if one was registered.

        Safe to call multiple times; only the first call after init()
        performs the unregistration.
        """
        if self._ipython_flush_logs_handler is not None:
            # A handler is only ever registered from the IN_IPYTHON branch
            # of init(), so the shell must still be available here.
            assert IN_IPYTHON
            ipython = get_ipython()
            assert ipython is not None
            ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
            self._ipython_flush_logs_handler = None

python/monarch/_src/actor/proc_mesh.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
TypeVar,
3232
)
3333

34-
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
35-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
3634
from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
3735
Alloc,
3836
AllocConstraints,
@@ -71,10 +69,12 @@
7169

7270
from monarch._src.actor.endpoint import endpoint
7371
from monarch._src.actor.future import DeprecatedNotAFuture, Future
72+
from monarch._src.actor.logging import LoggingManager
7473
from monarch._src.actor.shape import MeshTrait
7574
from monarch.tools.config import Workspace
7675
from monarch.tools.utils import conda as conda_utils
7776

77+
7878
HAS_TENSOR_ENGINE = False
7979
try:
8080
# Torch is needed for tensor engine
@@ -154,7 +154,7 @@ def __init__(
154154
self._rdma_manager: Optional["_RdmaManager"] = None
155155
self._debug_manager: Optional[DebugManager] = None
156156
self._code_sync_client: Optional[CodeSyncMeshClient] = None
157-
self._logging_mesh_client: Optional[LoggingMeshClient] = None
157+
self._logging_manager: LoggingManager = LoggingManager()
158158
self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
159159
self._fork_processes = _fork_processes
160160
self._stopped = False
@@ -194,14 +194,7 @@ async def _init_manager_actors_coro(
194194

195195
if _fork_processes:
196196
# logging mesh only makes sense with forked (remote or local) processes
197-
self._logging_mesh_client = await LoggingMeshClient.spawn(
198-
proc_mesh=proc_mesh
199-
)
200-
self._logging_mesh_client.set_mode(
201-
stream_to_client=True,
202-
aggregate_window_sec=3,
203-
level=logging.INFO,
204-
)
197+
await self._logging_manager.init(proc_mesh)
205198

206199
_rdma_manager = (
207200
# type: ignore[16]
@@ -471,12 +464,9 @@ async def logging_option(
471464
"Logging option is only available for allocators that fork processes. Allocators like LocalAllocator are not supported."
472465
)
473466

474-
if level < 0 or level > 255:
475-
raise ValueError("Invalid logging level: {}".format(level))
476467
await self.initialized
477468

478-
assert self._logging_mesh_client is not None
479-
self._logging_mesh_client.set_mode(
469+
await self._logging_manager.logging_option(
480470
stream_to_client=stream_to_client,
481471
aggregate_window_sec=aggregate_window_sec,
482472
level=level,
@@ -489,6 +479,8 @@ async def __aenter__(self) -> "ProcMesh":
489479

490480
def stop(self) -> Future[None]:
491481
async def _stop_nonblocking() -> None:
482+
self._logging_manager.stop()
483+
492484
await (await self._proc_mesh).stop_nonblocking()
493485
self._stopped = True
494486

@@ -505,6 +497,8 @@ async def __aexit__(
505497
# Finalizer to check if the proc mesh was closed properly.
506498
def __del__(self) -> None:
507499
if not self._stopped:
500+
self._logging_manager.stop()
501+
508502
warnings.warn(
509503
f"unstopped ProcMesh {self!r}",
510504
ResourceWarning,

python/tests/test_python_actors.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88
import asyncio
9+
import gc
910
import importlib.resources
1011
import logging
1112
import operator
@@ -718,6 +719,145 @@ async def test_logging_option_defaults() -> None:
718719
pass
719720

720721

722+
# oss_skip: pytest keeps complaining about mocking get_ipython module
723+
@pytest.mark.oss_skip
724+
@pytest.mark.timeout(180)
725+
async def test_flush_logs_ipython() -> None:
726+
"""Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
727+
# Save original file descriptors
728+
original_stdout_fd = os.dup(1) # stdout
729+
730+
try:
731+
# Create temporary files to capture output
732+
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
733+
stdout_path = stdout_file.name
734+
735+
# Redirect file descriptors to our temp files
736+
os.dup2(stdout_file.fileno(), 1)
737+
738+
# Also redirect Python's sys.stdout
739+
original_sys_stdout = sys.stdout
740+
sys.stdout = stdout_file
741+
742+
try:
743+
# Mock IPython environment
744+
class MockExecutionResult:
745+
pass
746+
747+
class MockEvents:
748+
def __init__(self):
749+
self.callbacks = {}
750+
self.registers = 0
751+
self.unregisters = 0
752+
753+
def register(self, event_name, callback):
754+
if event_name not in self.callbacks:
755+
self.callbacks[event_name] = []
756+
self.callbacks[event_name].append(callback)
757+
self.registers += 1
758+
759+
def unregister(self, event_name, callback):
760+
if event_name not in self.callbacks:
761+
raise ValueError(f"Event {event_name} not registered")
762+
assert callback in self.callbacks[event_name]
763+
self.callbacks[event_name].remove(callback)
764+
self.unregisters += 1
765+
766+
def trigger(self, event_name, *args, **kwargs):
767+
if event_name in self.callbacks:
768+
for callback in self.callbacks[event_name]:
769+
callback(*args, **kwargs)
770+
771+
class MockIPython:
772+
def __init__(self):
773+
self.events = MockEvents()
774+
775+
mock_ipython = MockIPython()
776+
777+
with unittest.mock.patch(
778+
"monarch._src.actor.logging.get_ipython",
779+
lambda: mock_ipython,
780+
), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
781+
# Make sure we can register and unregister callbacks
782+
for _ in range(3):
783+
pm1 = await proc_mesh(gpus=2)
784+
pm2 = await proc_mesh(gpus=2)
785+
am1 = await pm1.spawn("printer", Printer)
786+
am2 = await pm2.spawn("printer", Printer)
787+
788+
# Set aggregation window to ensure logs are buffered
789+
await pm1.logging_option(
790+
stream_to_client=True, aggregate_window_sec=600
791+
)
792+
await pm2.logging_option(
793+
stream_to_client=True, aggregate_window_sec=600
794+
)
795+
await asyncio.sleep(1)
796+
797+
# Generate some logs that will be aggregated
798+
for _ in range(5):
799+
await am1.print.call("ipython1 test log")
800+
await am2.print.call("ipython2 test log")
801+
802+
# Trigger the post_run_cell event which should flush logs
803+
mock_ipython.events.trigger(
804+
"post_run_cell", MockExecutionResult()
805+
)
806+
807+
gc.collect()
808+
809+
assert mock_ipython.events.registers == 6
810+
# TODO: figure out why the latest unregister is not called
811+
assert mock_ipython.events.unregisters == 4
812+
assert len(mock_ipython.events.callbacks["post_run_cell"]) == 2
813+
814+
# Flush all outputs
815+
stdout_file.flush()
816+
os.fsync(stdout_file.fileno())
817+
818+
finally:
819+
# Restore Python's sys.stdout
820+
sys.stdout = original_sys_stdout
821+
822+
# Restore original file descriptors
823+
os.dup2(original_stdout_fd, 1)
824+
825+
# Read the captured output
826+
with open(stdout_path, "r") as f:
827+
stdout_content = f.read()
828+
829+
# Clean up temp files
830+
os.unlink(stdout_path)
831+
832+
# Verify that logs were flushed when the post_run_cell event was triggered
833+
# We should see the aggregated logs in the output
834+
assert (
835+
len(
836+
re.findall(
837+
r"\[10 similar log lines\].*ipython1 test log", stdout_content
838+
)
839+
)
840+
== 3
841+
), stdout_content
842+
843+
assert (
844+
len(
845+
re.findall(
846+
r"\[10 similar log lines\].*ipython2 test log", stdout_content
847+
)
848+
)
849+
== 3
850+
), stdout_content
851+
852+
finally:
853+
# Ensure file descriptors are restored even if something goes wrong
854+
try:
855+
os.dup2(original_stdout_fd, 1)
856+
os.close(original_stdout_fd)
857+
except OSError:
858+
pass
859+
860+
721861
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
722862
@pytest.mark.oss_skip
723863
async def test_flush_logs_fast_exit() -> None:
@@ -849,7 +989,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
849989
for _ in range(10):
850990
await am.print.call("aggregated log line")
851991

852-
log_mesh = pm._logging_mesh_client
992+
log_mesh = pm._logging_manager._logging_mesh_client
853993
assert log_mesh is not None
854994
futures = []
855995
for _ in range(5):

0 commit comments

Comments
 (0)