|
6 | 6 |
|
7 | 7 | # pyre-unsafe
|
8 | 8 | import asyncio
|
| 9 | +import gc |
9 | 10 | import importlib.resources
|
10 | 11 | import logging
|
11 | 12 | import operator
|
@@ -576,7 +577,7 @@ async def test_actor_log_streaming() -> None:
|
576 | 577 | await am.log.call("has log streaming as level matched")
|
577 | 578 |
|
578 | 579 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
579 |
| - log_mesh = pm._logging_mesh_client |
| 580 | + log_mesh = pm._logging_manager._logging_mesh_client |
580 | 581 | assert log_mesh is not None
|
581 | 582 | Future(coro=log_mesh.flush().spawn().task()).get()
|
582 | 583 |
|
@@ -695,7 +696,7 @@ async def test_logging_option_defaults() -> None:
|
695 | 696 | await am.log.call("log streaming")
|
696 | 697 |
|
697 | 698 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
698 |
| - log_mesh = pm._logging_mesh_client |
| 699 | + log_mesh = pm._logging_manager._logging_mesh_client |
699 | 700 | assert log_mesh is not None
|
700 | 701 | Future(coro=log_mesh.flush().spawn().task()).get()
|
701 | 702 |
|
@@ -750,6 +751,149 @@ async def test_logging_option_defaults() -> None:
|
750 | 751 | pass
|
751 | 752 |
|
752 | 753 |
|
# oss_skip: pytest keeps complaining about mocking get_ipython module
@pytest.mark.oss_skip
@pytest.mark.timeout(180)
async def test_flush_logs_ipython() -> None:
    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.

    The test redirects file descriptor 1 and ``sys.stdout`` to a temporary
    file, patches ``monarch._src.actor.logging`` so that it sees a mock
    IPython shell, and then, three times in a row, creates two proc meshes,
    emits buffered log lines from each and fires ``post_run_cell``. It checks
    both the register/unregister bookkeeping on the mock event manager and
    the aggregated log lines captured in the temporary file.

    The temporary file is removed in the outermost ``finally`` so that it is
    cleaned up even when an assertion fails while stdout is still redirected.
    """
    # Keep a duplicate of the real stdout descriptor so it can be restored.
    original_stdout_fd = os.dup(1)  # stdout
    # Path of the capture file; stays None until the file has been created.
    stdout_path = None

    try:
        # delete=False because the file is re-opened by path below; it is
        # removed explicitly in the outer ``finally``.
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
            stdout_path = stdout_file.name

            # Redirect the OS-level stdout descriptor to the temp file
            os.dup2(stdout_file.fileno(), 1)

            # Also redirect Python's sys.stdout
            original_sys_stdout = sys.stdout
            sys.stdout = stdout_file

            try:
                # Mock IPython environment
                class MockExecutionResult:
                    """Placeholder for the argument passed to post_run_cell callbacks."""

                class MockEvents:
                    """Minimal event manager that counts register/unregister calls."""

                    def __init__(self):
                        self.callbacks = {}
                        self.registers = 0
                        self.unregisters = 0

                    def register(self, event_name, callback):
                        self.callbacks.setdefault(event_name, []).append(callback)
                        self.registers += 1

                    def unregister(self, event_name, callback):
                        if event_name not in self.callbacks:
                            raise ValueError(f"Event {event_name} not registered")
                        assert callback in self.callbacks[event_name]
                        self.callbacks[event_name].remove(callback)
                        self.unregisters += 1

                    def trigger(self, event_name, *args, **kwargs):
                        # Iterate over a snapshot so that a callback which
                        # unregisters itself while running cannot cause
                        # another callback to be skipped.
                        for callback in list(self.callbacks.get(event_name, ())):
                            callback(*args, **kwargs)

                class MockIPython:
                    """Stand-in shell object exposing only ``events``."""

                    def __init__(self):
                        self.events = MockEvents()

                mock_ipython = MockIPython()

                with unittest.mock.patch(
                    "monarch._src.actor.logging.get_ipython",
                    lambda: mock_ipython,
                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
                    # Make sure we can register and unregister callbacks
                    for i in range(3):
                        pm1 = await proc_mesh(gpus=2)
                        pm2 = await proc_mesh(gpus=2)
                        am1 = await pm1.spawn("printer", Printer)
                        am2 = await pm2.spawn("printer", Printer)

                        # Set aggregation window to ensure logs are buffered
                        await pm1.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        await pm2.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        # Rebinding pm1/pm2 above is expected to have
                        # unregistered the two callbacks of the previous
                        # iteration; the current pair has just been added on
                        # top of one pre-existing registration.
                        assert mock_ipython.events.unregisters == 2 * i
                        assert mock_ipython.events.registers == 1 + 2 * (i + 1)
                        await asyncio.sleep(1)

                        # Generate some logs that will be aggregated
                        for _ in range(5):
                            await am1.print.call("ipython1 test log")
                            await am2.print.call("ipython2 test log")

                        # Trigger the post_run_cell event which should flush logs
                        mock_ipython.events.trigger(
                            "post_run_cell", MockExecutionResult()
                        )

                    # Flush all outputs
                    stdout_file.flush()
                    os.fsync(stdout_file.fileno())

                    gc.collect()

                    assert mock_ipython.events.registers == 7
                    # The two meshes from the last iteration are not
                    # unregistered here, presumably because the pm1/pm2
                    # locals still reference them.
                    assert mock_ipython.events.unregisters == 4
                    assert len(mock_ipython.events.callbacks["post_run_cell"]) == 3

            finally:
                # Restore Python's sys.stdout
                sys.stdout = original_sys_stdout

            # Restore original file descriptors
            os.dup2(original_stdout_fd, 1)

            # Read the captured output
            with open(stdout_path, "r") as f:
                stdout_content = f.read()

            # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils

            # Verify that logs were flushed when the post_run_cell event was triggered
            # We should see the aggregated logs in the output
            assert (
                len(
                    re.findall(
                        r"\[10 similar log lines\].*ipython1 test log", stdout_content
                    )
                )
                == 3
            ), stdout_content

            assert (
                len(
                    re.findall(
                        r"\[10 similar log lines\].*ipython2 test log", stdout_content
                    )
                )
                == 3
            ), stdout_content

    finally:
        # Ensure file descriptors are restored even if something goes wrong
        try:
            os.dup2(original_stdout_fd, 1)
            os.close(original_stdout_fd)
        except OSError:
            # Best effort: the descriptor may already be restored or closed.
            pass
        finally:
            # Clean up the temp file on every path, including failed assertions.
            if stdout_path is not None:
                os.unlink(stdout_path)
| 896 | + |
753 | 897 | # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
|
754 | 898 | @pytest.mark.oss_skip
|
755 | 899 | async def test_flush_logs_fast_exit() -> None:
|
@@ -824,7 +968,7 @@ async def test_flush_on_disable_aggregation() -> None:
|
824 | 968 | await am.print.call("single log line")
|
825 | 969 |
|
826 | 970 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
827 |
| - log_mesh = pm._logging_mesh_client |
| 971 | + log_mesh = pm._logging_manager._logging_mesh_client |
828 | 972 | assert log_mesh is not None
|
829 | 973 | Future(coro=log_mesh.flush().spawn().task()).get()
|
830 | 974 |
|
@@ -884,7 +1028,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
|
884 | 1028 | for _ in range(10):
|
885 | 1029 | await am.print.call("aggregated log line")
|
886 | 1030 |
|
887 |
| - log_mesh = pm._logging_mesh_client |
| 1031 | + log_mesh = pm._logging_manager._logging_mesh_client |
888 | 1032 | assert log_mesh is not None
|
889 | 1033 | futures = []
|
890 | 1034 | for _ in range(5):
|
@@ -937,7 +1081,7 @@ async def test_adjust_aggregation_window() -> None:
|
937 | 1081 | await am.print.call("second batch of logs")
|
938 | 1082 |
|
939 | 1083 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
940 |
| - log_mesh = pm._logging_mesh_client |
| 1084 | + log_mesh = pm._logging_manager._logging_mesh_client |
941 | 1085 | assert log_mesh is not None
|
942 | 1086 | Future(coro=log_mesh.flush().spawn().task()).get()
|
943 | 1087 |
|
|
0 commit comments