Skip to content

Commit 6c1dcc5

Browse files
James Sun authored and facebook-github-bot committed
flush log upon ipython notebook cell exit (meta-pytorch#816)
Summary: Pull Request resolved: meta-pytorch#816 In ipython notebook, a cell can end fast. Yet the process can still run in the background. However, the background process will not flush logs to the existing cell anymore. The patch registers the flush function upon a cell exiting. Reviewed By: ahmadsharif1 Differential Revision: D79982702
1 parent d6840e3 commit 6c1dcc5

File tree

3 files changed

+247
-15
lines changed

3 files changed

+247
-15
lines changed

python/monarch/_src/actor/logging.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import gc
10+
import logging
11+
12+
from typing import Callable
13+
14+
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
15+
16+
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
17+
from monarch._src.actor.future import Future
18+
19+
# True only when IPython is importable AND we are currently inside an
# interactive shell (get_ipython() returns None when run as a plain script).
IN_IPYTHON = False
try:
    # Check if we are in ipython environment
    # pyre-ignore[21]
    from IPython import get_ipython

    # pyre-ignore[21]
    from IPython.core.interactiveshell import ExecutionResult

    IN_IPYTHON = get_ipython() is not None
except ImportError:
    # IPython not installed: IN_IPYTHON stays False and the names above
    # remain unbound; all uses below are guarded by IN_IPYTHON.
    pass
class LoggingManager:
34+
def __init__(self) -> None:
35+
self._logging_mesh_client: LoggingMeshClient | None = None
36+
self._ipython_flush_logs_handler: Callable[..., None] | None = None
37+
38+
async def init(self, proc_mesh: HyProcMesh) -> None:
39+
if self._logging_mesh_client is not None:
40+
return
41+
42+
self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
43+
self._logging_mesh_client.set_mode(
44+
stream_to_client=True,
45+
aggregate_window_sec=3,
46+
level=logging.INFO,
47+
)
48+
49+
if IN_IPYTHON:
50+
# For ipython environment, a cell can end fast with threads running in background.
51+
# Flush all the ongoing logs proactively to avoid missing logs.
52+
assert self._logging_mesh_client is not None
53+
logging_client: LoggingMeshClient = self._logging_mesh_client
54+
ipython = get_ipython()
55+
56+
# pyre-ignore[11]
57+
def flush_logs(_: ExecutionResult) -> None:
58+
try:
59+
Future(coro=logging_client.flush().spawn().task()).get(3)
60+
except TimeoutError:
61+
# We need to prevent failed proc meshes not coming back
62+
pass
63+
64+
# Force to recycle previous undropped proc_mesh.
65+
# Otherwise, we may end up with unregisterd dead callbacks.
66+
gc.collect()
67+
68+
# Store the handler reference so we can unregister it later
69+
self._ipython_flush_logs_handler = flush_logs
70+
ipython.events.register("post_run_cell", flush_logs)
71+
72+
async def logging_option(
73+
self,
74+
stream_to_client: bool = True,
75+
aggregate_window_sec: int | None = 3,
76+
level: int = logging.INFO,
77+
) -> None:
78+
if level < 0 or level > 255:
79+
raise ValueError("Invalid logging level: {}".format(level))
80+
81+
assert self._logging_mesh_client is not None
82+
self._logging_mesh_client.set_mode(
83+
stream_to_client=stream_to_client,
84+
aggregate_window_sec=aggregate_window_sec,
85+
level=level,
86+
)
87+
88+
def stop(self) -> None:
89+
if self._ipython_flush_logs_handler is not None:
90+
assert IN_IPYTHON
91+
ipython = get_ipython()
92+
assert ipython is not None
93+
ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
94+
self._ipython_flush_logs_handler = None

python/monarch/_src/actor/proc_mesh.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
)
3434
from weakref import WeakValueDictionary
3535

36-
from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
3736
from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
3837
Alloc,
3938
AllocConstraints,
@@ -67,10 +66,12 @@
6766

6867
from monarch._src.actor.endpoint import endpoint
6968
from monarch._src.actor.future import DeprecatedNotAFuture, Future
69+
from monarch._src.actor.logging import LoggingManager
7070
from monarch._src.actor.shape import MeshTrait
7171
from monarch.tools.config import Workspace
7272
from monarch.tools.utils import conda as conda_utils
7373

74+
7475
HAS_TENSOR_ENGINE = False
7576
try:
7677
# Torch is needed for tensor engine
@@ -191,7 +192,7 @@ def __init__(
191192
# of whether this is a slice of a real proc_meshg
192193
self._slice = False
193194
self._code_sync_client: Optional[CodeSyncMeshClient] = None
194-
self._logging_mesh_client: Optional[LoggingMeshClient] = None
195+
self._logging_manager: LoggingManager = LoggingManager()
195196
self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
196197
self._stopped = False
197198
self._controller_controller: Optional["_ControllerController"] = None
@@ -309,14 +310,8 @@ async def task(
309310
) -> HyProcMesh:
310311
hy_proc_mesh = await hy_proc_mesh_task
311312

312-
pm._logging_mesh_client = await LoggingMeshClient.spawn(
313-
proc_mesh=hy_proc_mesh
314-
)
315-
pm._logging_mesh_client.set_mode(
316-
stream_to_client=True,
317-
aggregate_window_sec=3,
318-
level=logging.INFO,
319-
)
313+
# logging mesh only makes sense with forked (remote or local) processes
314+
await pm._logging_manager.init(hy_proc_mesh)
320315

321316
if setup_actor is not None:
322317
await setup_actor.setup.call()
@@ -483,12 +478,9 @@ async def logging_option(
483478
Returns:
484479
None
485480
"""
486-
if level < 0 or level > 255:
487-
raise ValueError("Invalid logging level: {}".format(level))
488481
await self.initialized
489482

490-
assert self._logging_mesh_client is not None
491-
self._logging_mesh_client.set_mode(
483+
await self._logging_manager.logging_option(
492484
stream_to_client=stream_to_client,
493485
aggregate_window_sec=aggregate_window_sec,
494486
level=level,
@@ -501,6 +493,8 @@ async def __aenter__(self) -> "ProcMesh":
501493

502494
def stop(self) -> Future[None]:
503495
async def _stop_nonblocking() -> None:
496+
self._logging_manager.stop()
497+
504498
await (await self._proc_mesh).stop_nonblocking()
505499
self._stopped = True
506500

@@ -517,6 +511,8 @@ async def __aexit__(
517511
# Finalizer to check if the proc mesh was closed properly.
518512
def __del__(self) -> None:
519513
if not self._stopped:
514+
self._logging_manager.stop()
515+
520516
warnings.warn(
521517
f"unstopped ProcMesh {self!r}",
522518
ResourceWarning,

python/tests/test_python_actors.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88
import asyncio
9+
import gc
910
import importlib.resources
1011
import logging
1112
import operator
@@ -729,6 +730,147 @@ async def test_logging_option_defaults() -> None:
729730
pass
730731

731732

733+
# oss_skip: pytest keeps complaining about mocking get_ipython module
@pytest.mark.oss_skip
@pytest.mark.timeout(180)
async def test_flush_logs_ipython() -> None:
    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
    # Save original file descriptors so they can be restored even on failure.
    original_stdout_fd = os.dup(1)  # stdout

    try:
        # Create temporary files to capture output
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
            stdout_path = stdout_file.name

            # Redirect file descriptors to our temp files
            os.dup2(stdout_file.fileno(), 1)

            # Also redirect Python's sys.stdout
            original_sys_stdout = sys.stdout
            sys.stdout = stdout_file

            try:
                # Mock IPython environment

                # Stand-in for IPython.core.interactiveshell.ExecutionResult;
                # the flush handler ignores its argument.
                class MockExecutionResult:
                    pass

                # Minimal replica of ipython.events: records register/unregister
                # counts so the assertions below can verify callback lifecycle.
                class MockEvents:
                    def __init__(self):
                        self.callbacks = {}
                        self.registers = 0
                        self.unregisters = 0

                    def register(self, event_name, callback):
                        if event_name not in self.callbacks:
                            self.callbacks[event_name] = []
                        self.callbacks[event_name].append(callback)
                        self.registers += 1

                    def unregister(self, event_name, callback):
                        if event_name not in self.callbacks:
                            raise ValueError(f"Event {event_name} not registered")
                        assert callback in self.callbacks[event_name]
                        self.callbacks[event_name].remove(callback)
                        self.unregisters += 1

                    # Simulates IPython firing an event (e.g. post_run_cell).
                    def trigger(self, event_name, *args, **kwargs):
                        if event_name in self.callbacks:
                            for callback in self.callbacks[event_name]:
                                callback(*args, **kwargs)

                class MockIPython:
                    def __init__(self):
                        self.events = MockEvents()

                mock_ipython = MockIPython()

                with unittest.mock.patch(
                    "monarch._src.actor.logging.get_ipython",
                    lambda: mock_ipython,
                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
                    # Make sure we can register and unregister callbacks
                    for _ in range(3):
                        pm1 = await proc_mesh(gpus=2)
                        pm2 = await proc_mesh(gpus=2)
                        am1 = await pm1.spawn("printer", Printer)
                        am2 = await pm2.spawn("printer", Printer)

                        # Set aggregation window to ensure logs are buffered
                        # (600s so nothing streams before the explicit flush).
                        await pm1.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        await pm2.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        await asyncio.sleep(1)

                        # Generate some logs that will be aggregated
                        for _ in range(5):
                            await am1.print.call("ipython1 test log")
                            await am2.print.call("ipython2 test log")

                        # Trigger the post_run_cell event which should flush logs
                        mock_ipython.events.trigger(
                            "post_run_cell", MockExecutionResult()
                        )

                    # Collect dropped proc meshes so their handlers unregister.
                    gc.collect()

                    # 3 iterations x 2 proc meshes = 6 registrations.
                    assert mock_ipython.events.registers == 6
                    # TODO: figure out why the latest unregister is not called
                    assert mock_ipython.events.unregisters == 4
                    assert len(mock_ipython.events.callbacks["post_run_cell"]) == 2

                    # Flush all outputs
                    stdout_file.flush()
                    os.fsync(stdout_file.fileno())

            finally:
                # Restore Python's sys.stdout
                sys.stdout = original_sys_stdout

                # Restore original file descriptors
                os.dup2(original_stdout_fd, 1)

        # Read the captured output
        with open(stdout_path, "r") as f:
            stdout_content = f.read()

        # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils

        # Clean up temp files
        os.unlink(stdout_path)

        # Verify that logs were flushed when the post_run_cell event was triggered
        # We should see the aggregated logs in the output
        assert (
            len(
                re.findall(
                    r"\[10 similar log lines\].*ipython1 test log", stdout_content
                )
            )
            == 3
        ), stdout_content

        assert (
            len(
                re.findall(
                    r"\[10 similar log lines\].*ipython2 test log", stdout_content
                )
            )
            == 3
        ), stdout_content

    finally:
        # Ensure file descriptors are restored even if something goes wrong
        try:
            os.dup2(original_stdout_fd, 1)
            os.close(original_stdout_fd)
        except OSError:
            pass
732874
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
733875
@pytest.mark.oss_skip
734876
async def test_flush_logs_fast_exit() -> None:
@@ -860,7 +1002,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
8601002
for _ in range(10):
8611003
await am.print.call("aggregated log line")
8621004

863-
log_mesh = pm._logging_mesh_client
1005+
log_mesh = pm._logging_manager._logging_mesh_client
8641006
assert log_mesh is not None
8651007
futures = []
8661008
for _ in range(5):

0 commit comments

Comments
 (0)