|
6 | 6 |
|
7 | 7 | # pyre-unsafe
|
8 | 8 | import asyncio
|
| 9 | +import gc |
9 | 10 | import importlib.resources
|
10 | 11 | import logging
|
11 | 12 | import operator
|
@@ -718,6 +719,145 @@ async def test_logging_option_defaults() -> None:
|
718 | 719 | pass
|
719 | 720 |
|
720 | 721 |
|
| 722 | +# oss_skip: pytest keeps complaining about mocking get_ipython module |
| 723 | +@pytest.mark.oss_skip |
| 724 | +@pytest.mark.timeout(180) |
| 725 | +async def test_flush_logs_ipython() -> None: |
| 726 | + """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" |
| 727 | + # Save the original stdout file descriptor (fd 1) so it can be restored afterwards |
| 728 | + original_stdout_fd = os.dup(1) # stdout |
| 729 | + |
| 730 | + try: |
| 731 | + # Create a temporary file to capture output (delete=False so it can be re-opened and read after it is closed) |
| 732 | + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: |
| 733 | + stdout_path = stdout_file.name |
| 734 | + |
| 735 | + # Redirect the stdout file descriptor (fd 1) to our temp file |
| 736 | + os.dup2(stdout_file.fileno(), 1) |
| 737 | + |
| 738 | + # Also redirect Python's sys.stdout |
| 739 | + original_sys_stdout = sys.stdout |
| 740 | + sys.stdout = stdout_file |
| 741 | + |
| 742 | + try: |
| 743 | + # Mock IPython environment |
| 744 | + class MockExecutionResult: |
| 745 | + pass |
| 746 | + |
| 747 | + class MockEvents: |
| 748 | + def __init__(self): |
| 749 | + self.callbacks = {} |
| 750 | + self.registers = 0 |
| 751 | + self.unregisters = 0 |
| 752 | + |
| 753 | + def register(self, event_name, callback): |
| 754 | + if event_name not in self.callbacks: |
| 755 | + self.callbacks[event_name] = [] |
| 756 | + self.callbacks[event_name].append(callback) |
| 757 | + self.registers += 1 |
| 758 | + |
| 759 | + def unregister(self, event_name, callback): |
| 760 | + if event_name not in self.callbacks: |
| 761 | + raise ValueError(f"Event {event_name} not registered") |
| 762 | + assert callback in self.callbacks[event_name] |
| 763 | + self.callbacks[event_name].remove(callback) |
| 764 | + self.unregisters += 1 |
| 765 | + |
| 766 | + def trigger(self, event_name, *args, **kwargs): |
| 767 | + if event_name in self.callbacks: |
| 768 | + for callback in self.callbacks[event_name]: |
| 769 | + callback(*args, **kwargs) |
| 770 | + |
| 771 | + class MockIPython: |
| 772 | + def __init__(self): |
| 773 | + self.events = MockEvents() |
| 774 | + |
| 775 | + mock_ipython = MockIPython() |
| 776 | + |
| 777 | + with unittest.mock.patch( |
| 778 | + "monarch._src.actor.logging.get_ipython", |
| 779 | + lambda: mock_ipython, |
| 780 | + ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): |
| 781 | + # Make sure we can register and unregister callbacks across three rounds of creating two proc meshes |
| 782 | + for _ in range(3): |
| 783 | + pm1 = await proc_mesh(gpus=2) |
| 784 | + pm2 = await proc_mesh(gpus=2) |
| 785 | + am1 = await pm1.spawn("printer", Printer) |
| 786 | + am2 = await pm2.spawn("printer", Printer) |
| 787 | + |
| 788 | + # Set a long (600 s) aggregation window to ensure logs are buffered rather than emitted on their own |
| 789 | + await pm1.logging_option( |
| 790 | + stream_to_client=True, aggregate_window_sec=600 |
| 791 | + ) |
| 792 | + await pm2.logging_option( |
| 793 | + stream_to_client=True, aggregate_window_sec=600 |
| 794 | + ) |
| 795 | + await asyncio.sleep(1) |
| 796 | + |
| 797 | + # Generate some logs that will be aggregated (5 calls per mesh; the assertions below expect 10 similar lines per mesh) |
| 798 | + for _ in range(5): |
| 799 | + await am1.print.call("ipython1 test log") |
| 800 | + await am2.print.call("ipython2 test log") |
| 801 | + |
| 802 | + # Trigger the post_run_cell event which should flush logs |
| 803 | + mock_ipython.events.trigger( |
| 804 | + "post_run_cell", MockExecutionResult() |
| 805 | + ) |
| 806 | + |
| 807 | + gc.collect() |
| 808 | + |
| 809 | + assert mock_ipython.events.registers == 6 |
| 810 | + # TODO: figure out why the latest unregister is not called (hence 4 unregisters and 2 callbacks still registered, not 6 and 0) |
| 811 | + assert mock_ipython.events.unregisters == 4 |
| 812 | + assert len(mock_ipython.events.callbacks["post_run_cell"]) == 2 |
| 813 | + |
| 814 | + # Flush all outputs |
| 815 | + stdout_file.flush() |
| 816 | + os.fsync(stdout_file.fileno()) |
| 817 | + |
| 818 | + finally: |
| 819 | + # Restore Python's sys.stdout |
| 820 | + sys.stdout = original_sys_stdout |
| 821 | + |
| 822 | + # Restore the original stdout file descriptor |
| 823 | + os.dup2(original_stdout_fd, 1) |
| 824 | + |
| 825 | + # Read the captured output |
| 826 | + with open(stdout_path, "r") as f: |
| 827 | + stdout_content = f.read() |
| 828 | + |
| 829 | + # Clean up the temp file |
| 830 | + os.unlink(stdout_path) |
| 831 | + |
| 832 | + # Verify that logs were flushed when the post_run_cell event was triggered |
| 833 | + # We should see the aggregated logs in the output, once per loop iteration (3 in total for each mesh) |
| 834 | + assert ( |
| 835 | + len( |
| 836 | + re.findall( |
| 837 | + r"\[10 similar log lines\].*ipython1 test log", stdout_content |
| 838 | + ) |
| 839 | + ) |
| 840 | + == 3 |
| 841 | + ), stdout_content |
| 842 | + |
| 843 | + assert ( |
| 844 | + len( |
| 845 | + re.findall( |
| 846 | + r"\[10 similar log lines\].*ipython2 test log", stdout_content |
| 847 | + ) |
| 848 | + ) |
| 849 | + == 3 |
| 850 | + ), stdout_content |
| 851 | + |
| 852 | + finally: |
| 853 | + # Ensure the stdout file descriptor is restored and the saved duplicate is closed even if something goes wrong. NOTE(review): the temp file at stdout_path is not removed if an exception is raised before os.unlink above |
| 854 | + try: |
| 855 | + os.dup2(original_stdout_fd, 1) |
| 856 | + os.close(original_stdout_fd) |
| 857 | + except OSError: |
| 858 | + pass |
| 859 | + |
| 860 | + |
721 | 861 | # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
|
722 | 862 | @pytest.mark.oss_skip
|
723 | 863 | async def test_flush_logs_fast_exit() -> None:
|
@@ -849,7 +989,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
|
849 | 989 | for _ in range(10):
|
850 | 990 | await am.print.call("aggregated log line")
|
851 | 991 |
|
852 |
| - log_mesh = pm._logging_mesh_client |
| 992 | + log_mesh = pm._logging_manager._logging_mesh_client |
853 | 993 | assert log_mesh is not None
|
854 | 994 | futures = []
|
855 | 995 | for _ in range(5):
|
|
0 commit comments