37
37
)
38
38
from monarch ._rust_bindings .monarch_hyperactor .proc import ActorId
39
39
from monarch ._rust_bindings .monarch_hyperactor .pytokio import PythonTask
40
+ from monarch ._rust_bindings .monarch_hyperactor .shape import Extent
40
41
41
42
from monarch ._src .actor .actor_mesh import ActorMesh , Channel , context , Port
42
- from monarch ._src .actor .allocator import AllocHandle
43
+ from monarch ._src .actor .allocator import AllocHandle , ProcessAllocator
43
44
from monarch ._src .actor .future import Future
44
45
from monarch ._src .actor .host_mesh import (
45
46
create_local_host_mesh ,
46
47
fake_in_process_host ,
47
48
HostMesh ,
48
49
)
49
- from monarch ._src .actor .proc_mesh import ProcMesh
50
+ from monarch ._src .actor .proc_mesh import _get_bootstrap_args , ProcMesh
50
51
from monarch ._src .actor .v1 .host_mesh import (
52
+ _bootstrap_cmd ,
51
53
fake_in_process_host as fake_in_process_host_v1 ,
52
54
HostMesh as HostMeshV1 ,
53
55
this_host as this_host_v1 ,
@@ -466,7 +468,7 @@ async def no_more(self) -> None:
466
468
467
469
468
470
@pytest .mark .parametrize ("v1" , [True , False ])
469
- @pytest .mark .timeout (30 )
471
+ @pytest .mark .timeout (60 )
470
472
async def test_async_concurrency (v1 : bool ):
471
473
"""Test that async endpoints will be processed concurrently."""
472
474
pm = spawn_procs_on_this_host (v1 , {})
@@ -603,8 +605,9 @@ def _handle_undeliverable_message(
603
605
return True
604
606
605
607
608
+ @pytest .mark .parametrize ("v1" , [True , False ])
606
609
@pytest .mark .timeout (60 )
607
- async def test_actor_log_streaming () -> None :
610
+ async def test_actor_log_streaming (v1 : bool ) -> None :
608
611
# Save original file descriptors
609
612
original_stdout_fd = os .dup (1 ) # stdout
610
613
original_stderr_fd = os .dup (2 ) # stderr
@@ -631,7 +634,7 @@ async def test_actor_log_streaming() -> None:
631
634
sys .stderr = stderr_file
632
635
633
636
try :
634
- pm = spawn_procs_on_this_host (v1 = False , per_host = {"gpus" : 2 })
637
+ pm = spawn_procs_on_this_host (v1 , per_host = {"gpus" : 2 })
635
638
am = pm .spawn ("printer" , Printer )
636
639
637
640
# Disable streaming logs to client
@@ -671,7 +674,10 @@ async def test_actor_log_streaming() -> None:
671
674
await am .print .call ("has print streaming too" )
672
675
await am .log .call ("has log streaming as level matched" )
673
676
674
- await pm .stop ()
677
+ if not v1 :
678
+ await pm .stop ()
679
+ else :
680
+ await asyncio .sleep (1 )
675
681
676
682
# Flush all outputs
677
683
stdout_file .flush ()
@@ -752,8 +758,9 @@ async def test_actor_log_streaming() -> None:
752
758
pass
753
759
754
760
761
+ @pytest .mark .parametrize ("v1" , [True , False ])
755
762
@pytest .mark .timeout (120 )
756
- async def test_alloc_based_log_streaming () -> None :
763
+ async def test_alloc_based_log_streaming (v1 : bool ) -> None :
757
764
"""Test both AllocHandle.stream_logs = False and True cases."""
758
765
759
766
async def test_stream_logs_case (stream_logs : bool , test_name : str ) -> None :
@@ -770,23 +777,45 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
770
777
771
778
try :
772
779
# Create proc mesh with custom stream_logs setting
773
- host_mesh = create_local_host_mesh ()
774
- alloc_handle = host_mesh ._alloc (hosts = 1 , gpus = 2 )
780
+ if not v1 :
781
+ host_mesh = create_local_host_mesh ()
782
+ alloc_handle = host_mesh ._alloc (hosts = 1 , gpus = 2 )
783
+
784
+ # Override the stream_logs setting
785
+ custom_alloc_handle = AllocHandle (
786
+ alloc_handle ._hy_alloc , alloc_handle ._extent , stream_logs
787
+ )
788
+
789
+ pm = ProcMesh .from_alloc (custom_alloc_handle )
790
+ else :
775
791
776
- # Override the stream_logs setting
777
- custom_alloc_handle = AllocHandle (
778
- alloc_handle ._hy_alloc , alloc_handle ._extent , stream_logs
779
- )
792
+ class ProcessAllocatorStreamLogs (ProcessAllocator ):
793
+ def _stream_logs (self ) -> bool :
794
+ return stream_logs
795
+
796
+ alloc = ProcessAllocatorStreamLogs (* _get_bootstrap_args ())
797
+
798
+ host_mesh = HostMeshV1 .allocate_nonblocking (
799
+ "host" ,
800
+ Extent (["hosts" ], [1 ]),
801
+ alloc ,
802
+ bootstrap_cmd = _bootstrap_cmd (),
803
+ )
804
+
805
+ pm = host_mesh .spawn_procs (name = "proc" , per_host = {"gpus" : 2 })
780
806
781
- pm = ProcMesh .from_alloc (custom_alloc_handle )
782
807
am = pm .spawn ("printer" , Printer )
783
808
784
809
await pm .initialized
785
810
786
811
for _ in range (5 ):
787
812
await am .print .call (f"{ test_name } print streaming" )
788
813
789
- await pm .stop ()
814
+ if not v1 :
815
+ await pm .stop ()
816
+ else :
817
+ # Wait for at least the aggregation window (3 seconds)
818
+ await asyncio .sleep (5 )
790
819
791
820
# Flush all outputs
792
821
stdout_file .flush ()
@@ -810,18 +839,18 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
810
839
# When stream_logs=False, logs should not be streamed to client
811
840
assert not re .search (
812
841
rf"similar log lines.*{ test_name } print streaming" , stdout_content
813
- ), f"stream_logs=True case: { stdout_content } "
842
+ ), f"stream_logs=False case: { stdout_content } "
814
843
assert re .search (
815
844
rf"{ test_name } print streaming" , stdout_content
816
- ), f"stream_logs=True case: { stdout_content } "
845
+ ), f"stream_logs=False case: { stdout_content } "
817
846
else :
818
847
# When stream_logs=True, logs should be streamed to client (no aggregation by default)
819
848
assert re .search (
820
849
rf"similar log lines.*{ test_name } print streaming" , stdout_content
821
- ), f"stream_logs=False case: { stdout_content } "
850
+ ), f"stream_logs=True case: { stdout_content } "
822
851
assert not re .search (
823
852
rf"\[[0-9]\]{ test_name } print streaming" , stdout_content
824
- ), f"stream_logs=False case: { stdout_content } "
853
+ ), f"stream_logs=True case: { stdout_content } "
825
854
826
855
finally :
827
856
# Ensure file descriptors are restored even if something goes wrong
@@ -836,8 +865,9 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
836
865
await test_stream_logs_case (True , "stream_logs_true" )
837
866
838
867
868
+ @pytest .mark .parametrize ("v1" , [True , False ])
839
869
@pytest .mark .timeout (60 )
840
- async def test_logging_option_defaults () -> None :
870
+ async def test_logging_option_defaults (v1 : bool ) -> None :
841
871
# Save original file descriptors
842
872
original_stdout_fd = os .dup (1 ) # stdout
843
873
original_stderr_fd = os .dup (2 ) # stderr
@@ -864,14 +894,18 @@ async def test_logging_option_defaults() -> None:
864
894
sys .stderr = stderr_file
865
895
866
896
try :
867
- pm = spawn_procs_on_this_host (v1 = False , per_host = {"gpus" : 2 })
897
+ pm = spawn_procs_on_this_host (v1 , per_host = {"gpus" : 2 })
868
898
am = pm .spawn ("printer" , Printer )
869
899
870
900
for _ in range (5 ):
871
901
await am .print .call ("print streaming" )
872
902
await am .log .call ("log streaming" )
873
903
874
- await pm .stop ()
904
+ if not v1 :
905
+ await pm .stop ()
906
+ else :
907
+ # Wait for > default aggregation window (3 seconds)
908
+ await asyncio .sleep (5 )
875
909
876
910
# Flush all outputs
877
911
stdout_file .flush ()
@@ -949,7 +983,8 @@ def __init__(self):
949
983
950
984
# oss_skip: pytest keeps complaining about mocking get_ipython module
951
985
@pytest .mark .oss_skip
952
- async def test_flush_called_only_once () -> None :
986
+ @pytest .mark .parametrize ("v1" , [True , False ])
987
+ async def test_flush_called_only_once (v1 : bool ) -> None :
953
988
"""Test that flush is called only once when ending an ipython cell"""
954
989
mock_ipython = MockIPython ()
955
990
with unittest .mock .patch (
@@ -961,8 +996,8 @@ async def test_flush_called_only_once() -> None:
961
996
"monarch._src.actor.logging.flush_all_proc_mesh_logs"
962
997
) as mock_flush :
963
998
# Create 2 proc meshes with a large aggregation window
964
- pm1 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
965
- _ = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
999
+ pm1 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1000
+ _ = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
966
1001
# flush not yet called unless post_run_cell
967
1002
assert mock_flush .call_count == 0
968
1003
assert mock_ipython .events .registers == 0
@@ -976,8 +1011,9 @@ async def test_flush_called_only_once() -> None:
976
1011
977
1012
# oss_skip: pytest keeps complaining about mocking get_ipython module
978
1013
@pytest .mark .oss_skip
1014
+ @pytest .mark .parametrize ("v1" , [True , False ])
979
1015
@pytest .mark .timeout (180 )
980
- async def test_flush_logs_ipython () -> None :
1016
+ async def test_flush_logs_ipython (v1 : bool ) -> None :
981
1017
"""Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
982
1018
# Save original file descriptors
983
1019
original_stdout_fd = os .dup (1 ) # stdout
@@ -1003,8 +1039,8 @@ async def test_flush_logs_ipython() -> None:
1003
1039
), unittest .mock .patch ("monarch._src.actor.logging.IN_IPYTHON" , True ):
1004
1040
# Make sure we can register and unregister callbacks
1005
1041
for _ in range (3 ):
1006
- pm1 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
1007
- pm2 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
1042
+ pm1 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1043
+ pm2 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1008
1044
am1 = pm1 .spawn ("printer" , Printer )
1009
1045
am2 = pm2 .spawn ("printer" , Printer )
1010
1046
@@ -1108,8 +1144,9 @@ async def test_flush_logs_fast_exit() -> None:
1108
1144
), process .stdout
1109
1145
1110
1146
1147
+ @pytest .mark .parametrize ("v1" , [True , False ])
1111
1148
@pytest .mark .timeout (60 )
1112
- async def test_flush_on_disable_aggregation () -> None :
1149
+ async def test_flush_on_disable_aggregation (v1 : bool ) -> None :
1113
1150
"""Test that logs are flushed when disabling aggregation.
1114
1151
1115
1152
This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
@@ -1130,7 +1167,7 @@ async def test_flush_on_disable_aggregation() -> None:
1130
1167
sys .stdout = stdout_file
1131
1168
1132
1169
try :
1133
- pm = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
1170
+ pm = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1134
1171
am = pm .spawn ("printer" , Printer )
1135
1172
1136
1173
# Set a long aggregation window to ensure logs aren't flushed immediately
@@ -1151,7 +1188,11 @@ async def test_flush_on_disable_aggregation() -> None:
1151
1188
for _ in range (5 ):
1152
1189
await am .print .call ("single log line" )
1153
1190
1154
- await pm .stop ()
1191
+ if not v1 :
1192
+ await pm .stop ()
1193
+ else :
1194
+ # Wait for > default aggregation window (3 secs)
1195
+ await asyncio .sleep (5 )
1155
1196
1156
1197
# Flush all outputs
1157
1198
stdout_file .flush ()
@@ -1197,14 +1238,15 @@ async def test_flush_on_disable_aggregation() -> None:
1197
1238
pass
1198
1239
1199
1240
1241
+ @pytest .mark .parametrize ("v1" , [True , False ])
1200
1242
@pytest .mark .timeout (120 )
1201
- async def test_multiple_ongoing_flushes_no_deadlock () -> None :
1243
+ async def test_multiple_ongoing_flushes_no_deadlock (v1 : bool ) -> None :
1202
1244
"""
1203
1245
The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
1204
1246
Because now a flush call is purely sync, it is very easy to get into a deadlock.
1205
1247
So we assert the last flush call will not get into such a state.
1206
1248
"""
1207
- pm = this_host (). spawn_procs ( per_host = {"gpus" : 4 })
1249
+ pm = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 4 })
1208
1250
am = pm .spawn ("printer" , Printer )
1209
1251
1210
1252
# Generate some logs that will be aggregated but not flushed immediately
@@ -1227,8 +1269,9 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
1227
1269
futures [- 1 ].get ()
1228
1270
1229
1271
1272
+ @pytest .mark .parametrize ("v1" , [True , False ])
1230
1273
@pytest .mark .timeout (60 )
1231
- async def test_adjust_aggregation_window () -> None :
1274
+ async def test_adjust_aggregation_window (v1 : bool ) -> None :
1232
1275
"""Test that the flush deadline is updated when the aggregation window is adjusted.
1233
1276
1234
1277
This tests the corner case: "This can happen if the user has adjusted the aggregation window."
@@ -1249,7 +1292,7 @@ async def test_adjust_aggregation_window() -> None:
1249
1292
sys .stdout = stdout_file
1250
1293
1251
1294
try :
1252
- pm = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
1295
+ pm = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1253
1296
am = pm .spawn ("printer" , Printer )
1254
1297
1255
1298
# Set a long aggregation window initially
@@ -1267,7 +1310,11 @@ async def test_adjust_aggregation_window() -> None:
1267
1310
for _ in range (3 ):
1268
1311
await am .print .call ("second batch of logs" )
1269
1312
1270
- await pm .stop ()
1313
+ if not v1 :
1314
+ await pm .stop ()
1315
+ else :
1316
+ # Wait for > aggregation window (2 secs)
1317
+ await asyncio .sleep (4 )
1271
1318
1272
1319
# Flush all outputs
1273
1320
stdout_file .flush ()
0 commit comments