From f4e24a6361bf6ffc74efb041e20d05e8570e93f9 Mon Sep 17 00:00:00 2001 From: James Sun Date: Sat, 23 Aug 2025 14:53:07 -0700 Subject: [PATCH] flush log upon ipython notebook cell exit (#816) Summary: Pull Request resolved: https://github.com/meta-pytorch/monarch/pull/816 In ipython notebook, a cell can end fast. Yet the process can still run in the background. However, the background process will not flush logs to the existing cell anymore. The patch registers the flush function upon a cell exiting. Reviewed By: ahmadsharif1 Differential Revision: D79982702 --- python/monarch/_src/actor/logging.py | 94 ++++++++++++++ python/monarch/_src/actor/proc_mesh.py | 23 ++-- python/tests/python_actor_test_binary.py | 2 +- python/tests/test_python_actors.py | 156 ++++++++++++++++++++++- 4 files changed, 255 insertions(+), 20 deletions(-) create mode 100644 python/monarch/_src/actor/logging.py diff --git a/python/monarch/_src/actor/logging.py b/python/monarch/_src/actor/logging.py new file mode 100644 index 000000000..f56003bb5 --- /dev/null +++ b/python/monarch/_src/actor/logging.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import gc +import logging + +from typing import Callable + +from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient + +from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh +from monarch._src.actor.future import Future + +IN_IPYTHON = False +try: + # Check if we are in ipython environment + # pyre-ignore[21] + from IPython import get_ipython + + # pyre-ignore[21] + from IPython.core.interactiveshell import ExecutionResult + + IN_IPYTHON = get_ipython() is not None +except ImportError: + pass + + +class LoggingManager: + def __init__(self) -> None: + self._logging_mesh_client: LoggingMeshClient | None = None + self._ipython_flush_logs_handler: Callable[..., None] | None = None + + async def init(self, proc_mesh: HyProcMesh) -> None: + if self._logging_mesh_client is not None: + return + + self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh) + self._logging_mesh_client.set_mode( + stream_to_client=True, + aggregate_window_sec=3, + level=logging.INFO, + ) + + if IN_IPYTHON: + # For ipython environment, a cell can end fast with threads running in background. + # Flush all the ongoing logs proactively to avoid missing logs. + assert self._logging_mesh_client is not None + logging_client: LoggingMeshClient = self._logging_mesh_client + ipython = get_ipython() + + # pyre-ignore[11] + def flush_logs(_: ExecutionResult) -> None: + try: + Future(coro=logging_client.flush().spawn().task()).get(3) + except TimeoutError: + # We need to prevent failed proc meshes not coming back + pass + + # Force to recycle previous undropped proc_mesh. + # Otherwise, we may end up with unregisterd dead callbacks. + gc.collect() + + # Store the handler reference so we can unregister it later + self._ipython_flush_logs_handler = flush_logs + ipython.events.register("post_run_cell", flush_logs) + + async def logging_option( + self, + stream_to_client: bool = True, + aggregate_window_sec: int | None = 3, + level: int = logging.INFO, + ) -> None: + if level < 0 or level > 255: + raise ValueError("Invalid logging level: {}".format(level)) + + assert self._logging_mesh_client is not None + self._logging_mesh_client.set_mode( + stream_to_client=stream_to_client, + aggregate_window_sec=aggregate_window_sec, + level=level, + ) + + def stop(self) -> None: + if self._ipython_flush_logs_handler is not None: + assert IN_IPYTHON + ipython = get_ipython() + assert ipython is not None + ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler) + self._ipython_flush_logs_handler = None diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 263708e07..318e69274 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -33,7 +33,6 @@ ) from weakref import WeakValueDictionary -from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension Alloc, AllocConstraints, @@ -67,10 +66,12 @@ from monarch._src.actor.endpoint import endpoint from monarch._src.actor.future import DeprecatedNotAFuture, Future +from monarch._src.actor.logging import LoggingManager from monarch._src.actor.shape import MeshTrait from monarch.tools.config import Workspace from monarch.tools.utils import conda as conda_utils + HAS_TENSOR_ENGINE = False try: # Torch is needed for tensor engine @@ -191,7 +192,7 @@ def __init__( # of whether this is a slice of a real proc_meshg self._slice = False self._code_sync_client: Optional[CodeSyncMeshClient] = None - self._logging_mesh_client: Optional[LoggingMeshClient] = None + self._logging_manager: LoggingManager = LoggingManager() self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh self._stopped = False self._controller_controller: Optional["_ControllerController"] = None @@ -311,14 +312,7 @@ async def task( ) -> HyProcMesh: hy_proc_mesh = await hy_proc_mesh_task - pm._logging_mesh_client = await LoggingMeshClient.spawn( - proc_mesh=hy_proc_mesh - ) - pm._logging_mesh_client.set_mode( - stream_to_client=True, - aggregate_window_sec=3, - level=logging.INFO, - ) + await pm._logging_manager.init(hy_proc_mesh) if setup_actor is not None: await setup_actor.setup.call() @@ -482,12 +476,9 @@ async def logging_option( Returns: None """ - if level < 0 or level > 255: - raise ValueError("Invalid logging level: {}".format(level)) await self.initialized - assert self._logging_mesh_client is not None - self._logging_mesh_client.set_mode( + await self._logging_manager.logging_option( stream_to_client=stream_to_client, aggregate_window_sec=aggregate_window_sec, level=level, @@ -499,6 +490,8 @@ async def __aenter__(self) -> "ProcMesh": return self def stop(self) -> Future[None]: + self._logging_manager.stop() + async def _stop_nonblocking() -> None: await (await self._proc_mesh).stop_nonblocking() self._stopped = True @@ -516,6 +509,8 @@ async def __aexit__( # Finalizer to check if the proc mesh was closed properly. def __del__(self) -> None: if not self._stopped: + self._logging_manager.stop() + warnings.warn( f"unstopped ProcMesh {self!r}", ResourceWarning, diff --git a/python/tests/python_actor_test_binary.py b/python/tests/python_actor_test_binary.py index 9cff72087..3105bcdd2 100644 --- a/python/tests/python_actor_test_binary.py +++ b/python/tests/python_actor_test_binary.py @@ -42,7 +42,7 @@ async def _flush_logs() -> None: await am.print.call("has print streaming") # TODO: remove this completely once we hook the flush logic upon dropping device_mesh - log_mesh = pm._logging_mesh_client + log_mesh = pm._logging_manager._logging_mesh_client assert log_mesh is not None Future(coro=log_mesh.flush().spawn().task()).get() diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 4c194b909..c0ca1bd0a 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -6,6 +6,7 @@ # pyre-unsafe import asyncio +import gc import importlib.resources import logging import operator @@ -586,7 +587,7 @@ async def test_actor_log_streaming() -> None: await am.log.call("has log streaming as level matched") # TODO: remove this completely once we hook the flush logic upon dropping device_mesh - log_mesh = pm._logging_mesh_client + log_mesh = pm._logging_manager._logging_mesh_client assert log_mesh is not None Future(coro=log_mesh.flush().spawn().task()).get() @@ -705,7 +706,7 @@ async def test_logging_option_defaults() -> None: await am.log.call("log streaming") # TODO: remove this completely once we hook the flush logic upon dropping device_mesh - log_mesh = pm._logging_mesh_client + log_mesh = pm._logging_manager._logging_mesh_client assert log_mesh is not None Future(coro=log_mesh.flush().spawn().task()).get() @@ -760,6 +761,151 @@ async def test_logging_option_defaults() -> None: pass +# oss_skip: pytest keeps complaining about mocking get_ipython module +@pytest.mark.oss_skip +@pytest.mark.timeout(180) +async def test_flush_logs_ipython() -> None: + """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" + # Save original file descriptors + original_stdout_fd = os.dup(1) # stdout + + try: + # Create temporary files to capture output + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: + stdout_path = stdout_file.name + + # Redirect file descriptors to our temp files + os.dup2(stdout_file.fileno(), 1) + + # Also redirect Python's sys.stdout + original_sys_stdout = sys.stdout + sys.stdout = stdout_file + + try: + # Mock IPython environment + class MockExecutionResult: + pass + + class MockEvents: + def __init__(self): + self.callbacks = {} + self.registers = 0 + self.unregisters = 0 + + def register(self, event_name, callback): + if event_name not in self.callbacks: + self.callbacks[event_name] = [] + self.callbacks[event_name].append(callback) + self.registers += 1 + + def unregister(self, event_name, callback): + if event_name not in self.callbacks: + raise ValueError(f"Event {event_name} not registered") + assert callback in self.callbacks[event_name] + self.callbacks[event_name].remove(callback) + self.unregisters += 1 + + def trigger(self, event_name, *args, **kwargs): + if event_name in self.callbacks: + for callback in self.callbacks[event_name]: + callback(*args, **kwargs) + + class MockIPython: + def __init__(self): + self.events = MockEvents() + + mock_ipython = MockIPython() + + with unittest.mock.patch( + "monarch._src.actor.logging.get_ipython", + lambda: mock_ipython, + ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): + # Make sure we can register and unregister callbacks + for i in range(3): + pm1 = await proc_mesh(gpus=2) + pm2 = await proc_mesh(gpus=2) + am1 = await pm1.spawn("printer", Printer) + am2 = await pm2.spawn("printer", Printer) + + # Set aggregation window to ensure logs are buffered + await pm1.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + await pm2.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + assert mock_ipython.events.unregisters == 2 * i + # TODO: remove `1 +` from attaching controller_controller + assert mock_ipython.events.registers == 1 + 2 * (i + 1) + await asyncio.sleep(1) + + # Generate some logs that will be aggregated + for _ in range(5): + await am1.print.call("ipython1 test log") + await am2.print.call("ipython2 test log") + + # Trigger the post_run_cell event which should flush logs + mock_ipython.events.trigger( + "post_run_cell", MockExecutionResult() + ) + + # Flush all outputs + stdout_file.flush() + os.fsync(stdout_file.fileno()) + + gc.collect() + + # TODO: this should be 6 without attaching controller_controller + assert mock_ipython.events.registers == 7 + # There are many objects still taking refs + assert mock_ipython.events.unregisters == 4 + # TODO: same, this should be 2 + assert len(mock_ipython.events.callbacks["post_run_cell"]) == 3 + finally: + # Restore Python's sys.stdout + sys.stdout = original_sys_stdout + + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) + + # Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() + + # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils + + # Clean up temp files + os.unlink(stdout_path) + + # Verify that logs were flushed when the post_run_cell event was triggered + # We should see the aggregated logs in the output + assert ( + len( + re.findall( + r"\[10 similar log lines\].*ipython1 test log", stdout_content + ) + ) + == 3 + ), stdout_content + + assert ( + len( + re.findall( + r"\[10 similar log lines\].*ipython2 test log", stdout_content + ) + ) + == 3 + ), stdout_content + + finally: + # Ensure file descriptors are restored even if something goes wrong + try: + os.dup2(original_stdout_fd, 1) + os.close(original_stdout_fd) + except OSError: + pass + + # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @pytest.mark.oss_skip async def test_flush_logs_fast_exit() -> None: @@ -834,7 +980,7 @@ async def test_flush_on_disable_aggregation() -> None: await am.print.call("single log line") # TODO: remove this completely once we hook the flush logic upon dropping device_mesh - log_mesh = pm._logging_mesh_client + log_mesh = pm._logging_manager._logging_mesh_client assert log_mesh is not None Future(coro=log_mesh.flush().spawn().task()).get() @@ -894,7 +1040,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None: for _ in range(10): await am.print.call("aggregated log line") - log_mesh = pm._logging_mesh_client + log_mesh = pm._logging_manager._logging_mesh_client assert log_mesh is not None futures = [] for _ in range(5): @@ -947,7 +1093,7 @@ async def test_adjust_aggregation_window() -> None: await am.print.call("second batch of logs") # TODO: remove this completely once we hook the flush logic upon dropping device_mesh - log_mesh = pm._logging_mesh_client + log_mesh = pm._logging_manager._logging_mesh_client assert log_mesh is not None Future(coro=log_mesh.flush().spawn().task()).get()