Skip to content

Commit 47fb393

Browse files
highker and facebook-github-bot
authored and committed
check if we are in ipython env
Summary: In an IPython notebook, a cell can finish quickly, and logs that have not yet been flushed may be lost before they are printed to stdout/stderr. This patch registers a log-flush function that runs when a cell exits. Differential Revision: D79982702
1 parent af48096 commit 47fb393

File tree

2 files changed

+124
-1
lines changed

2 files changed

+124
-1
lines changed

python/monarch/_src/actor/proc_mesh.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@
6767
from monarch._src.actor.future import DeprecatedNotAFuture, Future
6868
from monarch._src.actor.shape import MeshTrait
6969

70+
HAS_IPYTHON = False
71+
try:
72+
# Check if we are in ipython environment
73+
# pyre-ignore[21]
74+
from IPython import get_ipython
75+
76+
HAS_IPYTHON = True
77+
except ImportError:
78+
pass
79+
7080
HAS_TENSOR_ENGINE = False
7181
try:
7282
# Torch is needed for tensor engine
@@ -163,7 +173,7 @@ async def _init_manager_actors_coro(
163173
proc_mesh_: "Shared[HyProcMesh]",
164174
setup: Callable[[], None] | None = None,
165175
) -> "HyProcMesh":
166-
proc_mesh = await proc_mesh_
176+
proc_mesh: HyProcMesh = await proc_mesh_
167177
# WARNING: it is unsafe to await self._proc_mesh here
168178
# because self._proc_mesh is the result of this function itself!
169179

@@ -173,6 +183,21 @@ async def _init_manager_actors_coro(
173183
aggregate_window_sec=3,
174184
level=logging.INFO,
175185
)
186+
if HAS_IPYTHON and get_ipython() is not None:
187+
# For ipython environment, a cell can end fast with threads running in background.
188+
# Flush all the ongoing logs proactively to avoid missing logs.
189+
assert self._logging_mesh_client is not None
190+
logging_client: LoggingMeshClient = self._logging_mesh_client
191+
ipython = get_ipython()
192+
193+
# pyre-ignore[21]
194+
from IPython.core.interactiveshell import ExecutionResult
195+
196+
# pyre-ignore[11]
197+
def flush_logs(_: ExecutionResult) -> None:
198+
return Future(coro=logging_client.flush(proc_mesh).spawn().task()).get()
199+
200+
ipython.events.register("post_run_cell", flush_logs)
176201

177202
_rdma_manager = (
178203
# type: ignore[16]

python/tests/test_python_actors.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,104 @@ async def test_logging_option_defaults() -> None:
724724
pass
725725

726726

727+
@pytest.mark.timeout(60)
728+
async def test_flush_logs_ipython() -> None:
729+
"""Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
730+
# Save original file descriptors
731+
original_stdout_fd = os.dup(1) # stdout
732+
733+
try:
734+
# Create temporary files to capture output
735+
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
736+
stdout_path = stdout_file.name
737+
738+
# Redirect file descriptors to our temp files
739+
os.dup2(stdout_file.fileno(), 1)
740+
741+
# Also redirect Python's sys.stdout
742+
original_sys_stdout = sys.stdout
743+
sys.stdout = stdout_file
744+
745+
try:
746+
# Mock IPython environment
747+
class MockExecutionResult:
748+
pass
749+
750+
class MockEvents:
751+
def __init__(self):
752+
self.callbacks = {}
753+
754+
def register(self, event_name, callback):
755+
if event_name not in self.callbacks:
756+
self.callbacks[event_name] = []
757+
self.callbacks[event_name].append(callback)
758+
759+
def trigger(self, event_name, *args, **kwargs):
760+
if event_name in self.callbacks:
761+
for callback in self.callbacks[event_name]:
762+
callback(*args, **kwargs)
763+
764+
class MockIPython:
765+
def __init__(self):
766+
self.events = MockEvents()
767+
768+
mock_ipython = MockIPython()
769+
770+
# Patch get_ipython to return our mock using unittest.mock
771+
import unittest.mock
772+
773+
with unittest.mock.patch(
774+
"monarch._src.actor.proc_mesh.get_ipython", lambda: mock_ipython
775+
):
776+
pm = await proc_mesh(gpus=2)
777+
am = await pm.spawn("printer", Printer)
778+
779+
# Set aggregation window to ensure logs are buffered
780+
await pm.logging_option(
781+
stream_to_client=True, aggregate_window_sec=600
782+
)
783+
await asyncio.sleep(1)
784+
785+
# Generate some logs that will be aggregated
786+
for _ in range(5):
787+
await am.print.call("ipython test log")
788+
789+
# Trigger the post_run_cell event which should flush logs
790+
mock_ipython.events.trigger("post_run_cell", MockExecutionResult())
791+
792+
# Flush all outputs
793+
stdout_file.flush()
794+
os.fsync(stdout_file.fileno())
795+
796+
finally:
797+
# Restore Python's sys.stdout
798+
sys.stdout = original_sys_stdout
799+
800+
# Restore original file descriptors
801+
os.dup2(original_stdout_fd, 1)
802+
803+
# Read the captured output
804+
with open(stdout_path, "r") as f:
805+
stdout_content = f.read()
806+
807+
# Clean up temp files
808+
os.unlink(stdout_path)
809+
810+
# Verify that logs were flushed when the post_run_cell event was triggered
811+
# We should see the aggregated logs in the output
812+
assert re.search(
813+
r"\[10 similar log lines\].*ipython test log", stdout_content
814+
), stdout_content
815+
816+
finally:
817+
# Ensure file descriptors are restored even if something goes wrong
818+
try:
819+
os.dup2(original_stdout_fd, 1)
820+
os.close(original_stdout_fd)
821+
except OSError:
822+
pass
823+
824+
727825
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
728826
@pytest.mark.oss_skip
729827
async def test_flush_logs_fast_exit() -> None:

0 commit comments

Comments
 (0)