Skip to content

Commit 144687a

Browse files
James Sun and facebook-github-bot
authored and committed
flush log upon ipython notebook cell exit (#816)
Summary: In an IPython notebook, a cell can end quickly while the process continues to run in the background. However, the background process will no longer flush logs to the finished cell. This patch registers the flush function to run when a cell exits.

Reviewed By: ahmadsharif1

Differential Revision: D79982702
1 parent 3c4cf2e commit 144687a

File tree

3 files changed

+247
-15
lines changed

3 files changed

+247
-15
lines changed

python/monarch/_src/actor/logging.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import gc
10+
import logging
11+
12+
from typing import Callable
13+
14+
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
15+
16+
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
17+
from monarch._src.actor.future import Future
18+
19+
IN_IPYTHON = False
20+
try:
21+
# Check if we are in ipython environment
22+
# pyre-ignore[21]
23+
from IPython import get_ipython
24+
25+
# pyre-ignore[21]
26+
from IPython.core.interactiveshell import ExecutionResult
27+
28+
IN_IPYTHON = get_ipython() is not None
29+
except ImportError:
30+
pass
31+
32+
33+
class LoggingManager:
34+
def __init__(self) -> None:
35+
self._logging_mesh_client: LoggingMeshClient | None = None
36+
self._ipython_flush_logs_handler: Callable[..., None] | None = None
37+
38+
async def init(self, proc_mesh: HyProcMesh) -> None:
39+
if self._logging_mesh_client is not None:
40+
return
41+
42+
self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
43+
self._logging_mesh_client.set_mode(
44+
stream_to_client=True,
45+
aggregate_window_sec=3,
46+
level=logging.INFO,
47+
)
48+
49+
if IN_IPYTHON:
50+
# For ipython environment, a cell can end fast with threads running in background.
51+
# Flush all the ongoing logs proactively to avoid missing logs.
52+
assert self._logging_mesh_client is not None
53+
logging_client: LoggingMeshClient = self._logging_mesh_client
54+
ipython = get_ipython()
55+
56+
# pyre-ignore[11]
57+
def flush_logs(_: ExecutionResult) -> None:
58+
try:
59+
Future(coro=logging_client.flush().spawn().task()).get(3)
60+
except TimeoutError:
61+
# We need to prevent failed proc meshes not coming back
62+
pass
63+
64+
# Force to recycle previous undropped proc_mesh.
65+
# Otherwise, we may end up with unregisterd dead callbacks.
66+
gc.collect()
67+
68+
# Store the handler reference so we can unregister it later
69+
self._ipython_flush_logs_handler = flush_logs
70+
ipython.events.register("post_run_cell", flush_logs)
71+
72+
async def logging_option(
73+
self,
74+
stream_to_client: bool = True,
75+
aggregate_window_sec: int | None = 3,
76+
level: int = logging.INFO,
77+
) -> None:
78+
if level < 0 or level > 255:
79+
raise ValueError("Invalid logging level: {}".format(level))
80+
81+
assert self._logging_mesh_client is not None
82+
self._logging_mesh_client.set_mode(
83+
stream_to_client=stream_to_client,
84+
aggregate_window_sec=aggregate_window_sec,
85+
level=level,
86+
)
87+
88+
def stop(self) -> None:
89+
if self._ipython_flush_logs_handler is not None:
90+
assert IN_IPYTHON
91+
ipython = get_ipython()
92+
assert ipython is not None
93+
ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
94+
self._ipython_flush_logs_handler = None

python/monarch/_src/actor/proc_mesh.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
)
3434
from weakref import WeakValueDictionary
3535

36-
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
3736
from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
3837
Alloc,
3938
AllocConstraints,
@@ -67,10 +66,12 @@
6766

6867
from monarch._src.actor.endpoint import endpoint
6968
from monarch._src.actor.future import DeprecatedNotAFuture, Future
69+
from monarch._src.actor.logging import LoggingManager
7070
from monarch._src.actor.shape import MeshTrait
7171
from monarch.tools.config import Workspace
7272
from monarch.tools.utils import conda as conda_utils
7373

74+
7475
HAS_TENSOR_ENGINE = False
7576
try:
7677
# Torch is needed for tensor engine
@@ -192,7 +193,7 @@ def __init__(
192193
# of whether this is a slice of a real proc_meshg
193194
self._slice = False
194195
self._code_sync_client: Optional[CodeSyncMeshClient] = None
195-
self._logging_mesh_client: Optional[LoggingMeshClient] = None
196+
self._logging_manager: LoggingManager = LoggingManager()
196197
self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
197198
self._fork_processes = _fork_processes
198199
self._stopped = False
@@ -322,14 +323,8 @@ async def task(
322323
hy_proc_mesh = await hy_proc_mesh_task
323324

324325
if fork_processes:
325-
pm._logging_mesh_client = await LoggingMeshClient.spawn(
326-
proc_mesh=hy_proc_mesh
327-
)
328-
pm._logging_mesh_client.set_mode(
329-
stream_to_client=True,
330-
aggregate_window_sec=3,
331-
level=logging.INFO,
332-
)
326+
# logging mesh only makes sense with forked (remote or local) processes
327+
await pm._logging_manager.init(hy_proc_mesh)
333328

334329
if setup_actor is not None:
335330
await setup_actor.setup.call()
@@ -501,12 +496,9 @@ async def logging_option(
501496
"Logging option is only available for allocators that fork processes. Allocators like LocalAllocator are not supported."
502497
)
503498

504-
if level < 0 or level > 255:
505-
raise ValueError("Invalid logging level: {}".format(level))
506499
await self.initialized
507500

508-
assert self._logging_mesh_client is not None
509-
self._logging_mesh_client.set_mode(
501+
await self._logging_manager.logging_option(
510502
stream_to_client=stream_to_client,
511503
aggregate_window_sec=aggregate_window_sec,
512504
level=level,
@@ -519,6 +511,8 @@ async def __aenter__(self) -> "ProcMesh":
519511

520512
def stop(self) -> Future[None]:
521513
async def _stop_nonblocking() -> None:
514+
self._logging_manager.stop()
515+
522516
await (await self._proc_mesh).stop_nonblocking()
523517
self._stopped = True
524518

@@ -535,6 +529,8 @@ async def __aexit__(
535529
# Finalizer to check if the proc mesh was closed properly.
536530
def __del__(self) -> None:
537531
if not self._stopped:
532+
self._logging_manager.stop()
533+
538534
warnings.warn(
539535
f"unstopped ProcMesh {self!r}",
540536
ResourceWarning,

python/tests/test_python_actors.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88
import asyncio
9+
import gc
910
import importlib.resources
1011
import logging
1112
import operator
@@ -718,6 +719,147 @@ async def test_logging_option_defaults() -> None:
718719
pass
719720

720721

722+
# oss_skip: pytest keeps complaining about mocking get_ipython module
723+
@pytest.mark.oss_skip
724+
@pytest.mark.timeout(180)
725+
async def test_flush_logs_ipython() -> None:
726+
"""Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
727+
# Save original file descriptors
728+
original_stdout_fd = os.dup(1) # stdout
729+
730+
try:
731+
# Create temporary files to capture output
732+
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
733+
stdout_path = stdout_file.name
734+
735+
# Redirect file descriptors to our temp files
736+
os.dup2(stdout_file.fileno(), 1)
737+
738+
# Also redirect Python's sys.stdout
739+
original_sys_stdout = sys.stdout
740+
sys.stdout = stdout_file
741+
742+
try:
743+
# Mock IPython environment
744+
class MockExecutionResult:
745+
pass
746+
747+
class MockEvents:
748+
def __init__(self):
749+
self.callbacks = {}
750+
self.registers = 0
751+
self.unregisters = 0
752+
753+
def register(self, event_name, callback):
754+
if event_name not in self.callbacks:
755+
self.callbacks[event_name] = []
756+
self.callbacks[event_name].append(callback)
757+
self.registers += 1
758+
759+
def unregister(self, event_name, callback):
760+
if event_name not in self.callbacks:
761+
raise ValueError(f"Event {event_name} not registered")
762+
assert callback in self.callbacks[event_name]
763+
self.callbacks[event_name].remove(callback)
764+
self.unregisters += 1
765+
766+
def trigger(self, event_name, *args, **kwargs):
767+
if event_name in self.callbacks:
768+
for callback in self.callbacks[event_name]:
769+
callback(*args, **kwargs)
770+
771+
class MockIPython:
772+
def __init__(self):
773+
self.events = MockEvents()
774+
775+
mock_ipython = MockIPython()
776+
777+
with unittest.mock.patch(
778+
"monarch._src.actor.logging.get_ipython",
779+
lambda: mock_ipython,
780+
), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
781+
# Make sure we can register and unregister callbacks
782+
for _ in range(3):
783+
pm1 = await proc_mesh(gpus=2)
784+
pm2 = await proc_mesh(gpus=2)
785+
am1 = await pm1.spawn("printer", Printer)
786+
am2 = await pm2.spawn("printer", Printer)
787+
788+
# Set aggregation window to ensure logs are buffered
789+
await pm1.logging_option(
790+
stream_to_client=True, aggregate_window_sec=600
791+
)
792+
await pm2.logging_option(
793+
stream_to_client=True, aggregate_window_sec=600
794+
)
795+
await asyncio.sleep(1)
796+
797+
# Generate some logs that will be aggregated
798+
for _ in range(5):
799+
await am1.print.call("ipython1 test log")
800+
await am2.print.call("ipython2 test log")
801+
802+
# Trigger the post_run_cell event which should flush logs
803+
mock_ipython.events.trigger(
804+
"post_run_cell", MockExecutionResult()
805+
)
806+
807+
gc.collect()
808+
809+
assert mock_ipython.events.registers == 6
810+
# TODO: figure out why the latest unregister is not called
811+
assert mock_ipython.events.unregisters == 4
812+
assert len(mock_ipython.events.callbacks["post_run_cell"]) == 2
813+
814+
# Flush all outputs
815+
stdout_file.flush()
816+
os.fsync(stdout_file.fileno())
817+
818+
finally:
819+
# Restore Python's sys.stdout
820+
sys.stdout = original_sys_stdout
821+
822+
# Restore original file descriptors
823+
os.dup2(original_stdout_fd, 1)
824+
825+
# Read the captured output
826+
with open(stdout_path, "r") as f:
827+
stdout_content = f.read()
828+
829+
# TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils
830+
831+
# Clean up temp files
832+
os.unlink(stdout_path)
833+
834+
# Verify that logs were flushed when the post_run_cell event was triggered
835+
# We should see the aggregated logs in the output
836+
assert (
837+
len(
838+
re.findall(
839+
r"\[10 similar log lines\].*ipython1 test log", stdout_content
840+
)
841+
)
842+
== 3
843+
), stdout_content
844+
845+
assert (
846+
len(
847+
re.findall(
848+
r"\[10 similar log lines\].*ipython2 test log", stdout_content
849+
)
850+
)
851+
== 3
852+
), stdout_content
853+
854+
finally:
855+
# Ensure file descriptors are restored even if something goes wrong
856+
try:
857+
os.dup2(original_stdout_fd, 1)
858+
os.close(original_stdout_fd)
859+
except OSError:
860+
pass
861+
862+
721863
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
722864
@pytest.mark.oss_skip
723865
async def test_flush_logs_fast_exit() -> None:
@@ -849,7 +991,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
849991
for _ in range(10):
850992
await am.print.call("aggregated log line")
851993

852-
log_mesh = pm._logging_mesh_client
994+
log_mesh = pm._logging_manager._logging_mesh_client
853995
assert log_mesh is not None
854996
futures = []
855997
for _ in range(5):

0 commit comments

Comments
 (0)