11
11
12
12
import click
13
13
from monarch ._rust_bindings .monarch_extension .blocking import blocking_function
14
-
15
14
from monarch ._rust_bindings .monarch_extension .panic import panicking_function
15
+ from monarch ._src .actor .proc_mesh import ProcMesh
16
+ from monarch ._src .actor .v1 .host_mesh import this_host as this_host_v1
17
+ from monarch ._src .actor .v1 .proc_mesh import ProcMesh as ProcMeshV1
18
+
19
+ from monarch .actor import Actor , endpoint , send , this_host
16
20
17
- from monarch .actor import Actor , endpoint , proc_mesh , send
21
+
22
+ def spawn_procs_on_this_host (
23
+ v1 : bool , per_host : dict [str , int ]
24
+ ) -> ProcMesh | ProcMeshV1 :
25
+ if v1 :
26
+ return this_host_v1 ().spawn_procs (name = "proc" , per_host = per_host )
27
+ else :
28
+ return this_host ().spawn_procs (per_host )
18
29
19
30
20
31
class ErrorActor (Actor ):
@@ -77,13 +88,13 @@ def cause_panic(self) -> None:
77
88
panicking_function ()
78
89
79
90
80
- def _run_error_test_sync (num_procs , sync_endpoint , endpoint_name ):
81
- proc = proc_mesh ( gpus = num_procs ). get ( )
91
+ def _run_error_test_sync (num_procs , sync_endpoint , endpoint_name , v1 ):
92
+ proc = spawn_procs_on_this_host ( v1 , { " gpus" : num_procs } )
82
93
if sync_endpoint :
83
94
actor_class = ErrorActorSync
84
95
else :
85
96
actor_class = ErrorActor
86
- error_actor = proc .spawn ("error_actor" , actor_class ). get ()
97
+ error_actor = proc .spawn ("error_actor" , actor_class )
87
98
88
99
# This output is checked in the test to make sure that the process actually got here
89
100
print ("Started function error_test" , flush = True )
@@ -103,7 +114,7 @@ def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
103
114
endpoint .call ().get ()
104
115
105
116
106
- def _run_error_test (num_procs , sync_endpoint , endpoint_name ):
117
+ def _run_error_test (num_procs , sync_endpoint , endpoint_name , v1 ):
107
118
import asyncio
108
119
109
120
if sync_endpoint :
@@ -112,7 +123,7 @@ def _run_error_test(num_procs, sync_endpoint, endpoint_name):
112
123
actor_class = ErrorActor
113
124
114
125
async def run_test ():
115
- proc = proc_mesh ( gpus = num_procs )
126
+ proc = spawn_procs_on_this_host ( v1 , per_host = { "gpus" : num_procs } )
116
127
error_actor = proc .spawn ("error_actor" , actor_class )
117
128
118
129
# This output is checked in the test to make sure that the process actually got here
@@ -145,29 +156,31 @@ def main():
145
156
@click .option ("--sync-test-impl" , type = bool , required = True )
146
157
@click .option ("--sync-endpoint" , type = bool , required = True )
147
158
@click .option ("--endpoint-name" , type = str , required = True )
148
- def error_endpoint (num_procs , sync_test_impl , sync_endpoint , endpoint_name ):
159
+ @click .option ("--v1" , type = bool , required = True )
160
+ def error_endpoint (num_procs , sync_test_impl , sync_endpoint , endpoint_name , v1 ):
149
161
print (
150
162
f"Running segfault test: { num_procs = } { sync_test_impl = } { sync_endpoint = } , { endpoint_name = } "
151
163
)
152
164
153
165
if sync_test_impl :
154
- _run_error_test_sync (num_procs , sync_endpoint , endpoint_name )
166
+ _run_error_test_sync (num_procs , sync_endpoint , endpoint_name , v1 )
155
167
else :
156
- _run_error_test (num_procs , sync_endpoint , endpoint_name )
168
+ _run_error_test (num_procs , sync_endpoint , endpoint_name , v1 )
157
169
158
170
159
171
@main .command ("error-bootstrap" )
160
- def error_bootstrap ():
172
+ @click .option ("--v1" , type = bool , required = True )
173
+ def error_bootstrap (v1 ):
161
174
print ("Started function error_bootstrap" , flush = True )
162
- proc_mesh (
163
- gpus = 4 , env = {"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING" : "1" }
175
+ spawn_procs_on_this_host (
176
+ v1 , { " gpus" : 4 } , env = {"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING" : "1" }
164
177
).initialized .get ()
165
178
166
179
167
- async def _error_unmonitored ():
180
+ async def _error_unmonitored (v1 ):
168
181
print ("Started function _error_unmonitored" , flush = True )
169
182
170
- proc = proc_mesh ( gpus = 1 )
183
+ proc = spawn_procs_on_this_host ( v1 , { " gpus" : 1 } )
171
184
actor = proc .spawn ("error_actor" , ErrorActor )
172
185
173
186
# fire and forget
@@ -183,11 +196,11 @@ async def _error_unmonitored():
183
196
184
197
"""
185
198
TODO: This test should be enabled when stop() is fully implemented.
186
- async def _error_unmonitored():
199
+ async def _error_unmonitored(v1 ):
187
200
print("I actually ran")
188
201
sys.stdout.flush()
189
202
190
- proc = proc_mesh( gpus=1 )
203
+ proc = spawn_procs_on_this_host(v1, {" gpus": 1} )
191
204
actor = proc.spawn("error_actor", ErrorActor)
192
205
193
206
# fire and forget
@@ -204,16 +217,17 @@ async def _error_unmonitored():
204
217
205
218
206
219
@main .command ("error-unmonitored" )
207
- def error_unmonitored ():
208
- asyncio .run (_error_unmonitored ())
220
+ @click .option ("--v1" , type = bool , required = True )
221
+ def error_unmonitored (v1 ):
222
+ asyncio .run (_error_unmonitored (v1 ))
209
223
210
224
211
- async def _error_cleanup ():
225
+ async def _error_cleanup (v1 ):
212
226
"""Test function that spawns an 8 process procmesh and calls an endpoint that returns a normal exception."""
213
227
print ("Started function _error_cleanup() for parent process" , flush = True )
214
228
215
229
# Spawn an 8 process procmesh
216
- proc = proc_mesh ( gpus = 8 )
230
+ proc = spawn_procs_on_this_host ( v1 , { " gpus" : 8 } )
217
231
error_actor = proc .spawn ("error_actor" , ErrorActor )
218
232
219
233
print ("Procmesh spawned, collecting child PIDs from actors" , flush = True )
@@ -239,9 +253,10 @@ async def _error_cleanup():
239
253
240
254
241
255
@main .command ("error-cleanup" )
242
- def error_cleanup ():
256
+ @click .option ("--v1" , type = bool , required = True )
257
+ def error_cleanup (v1 ):
243
258
"""Command that spawns an 8 process procmesh and calls an endpoint that returns a normal exception."""
244
- asyncio .run (_error_cleanup ())
259
+ asyncio .run (_error_cleanup (v1 ))
245
260
246
261
247
262
if __name__ == "__main__" :
0 commit comments