Skip to content

Commit 7601419

Browse files
dulinriley authored and meta-codesync[bot] committed
Parameterize test_actor_error on v1 (#1385)
Summary: Pull Request resolved: #1385. Part of: #1209. Parameterize the actor error tests over the v1 API, similar to how test_python_actors.py was parameterized in #1346. 11 of the 18 test entrypoints don't work with v1 yet because the supervision APIs they depend on are still being implemented. This sets up a framework for testing them all. Reviewed By: colin2328. Differential Revision: D83591768. fbshipit-source-id: 78ebb195ad90a01f1373f1a046d2d981ed7d06e3
1 parent 7c40432 commit 7601419

File tree

2 files changed

+172
-80
lines changed

2 files changed

+172
-80
lines changed

python/tests/error_test_binary.py

Lines changed: 38 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,21 @@
1111

1212
import click
1313
from monarch._rust_bindings.monarch_extension.blocking import blocking_function
14-
1514
from monarch._rust_bindings.monarch_extension.panic import panicking_function
15+
from monarch._src.actor.proc_mesh import ProcMesh
16+
from monarch._src.actor.v1.host_mesh import this_host as this_host_v1
17+
from monarch._src.actor.v1.proc_mesh import ProcMesh as ProcMeshV1
18+
19+
from monarch.actor import Actor, endpoint, send, this_host
1620

17-
from monarch.actor import Actor, endpoint, proc_mesh, send
21+
22+
def spawn_procs_on_this_host(
23+
v1: bool, per_host: dict[str, int]
24+
) -> ProcMesh | ProcMeshV1:
25+
if v1:
26+
return this_host_v1().spawn_procs(name="proc", per_host=per_host)
27+
else:
28+
return this_host().spawn_procs(per_host)
1829

1930

2031
class ErrorActor(Actor):
@@ -77,13 +88,13 @@ def cause_panic(self) -> None:
7788
panicking_function()
7889

7990

80-
def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
81-
proc = proc_mesh(gpus=num_procs).get()
91+
def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name, v1):
92+
proc = spawn_procs_on_this_host(v1, {"gpus": num_procs})
8293
if sync_endpoint:
8394
actor_class = ErrorActorSync
8495
else:
8596
actor_class = ErrorActor
86-
error_actor = proc.spawn("error_actor", actor_class).get()
97+
error_actor = proc.spawn("error_actor", actor_class)
8798

8899
# This output is checked in the test to make sure that the process actually got here
89100
print("Started function error_test", flush=True)
@@ -103,7 +114,7 @@ def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
103114
endpoint.call().get()
104115

105116

106-
def _run_error_test(num_procs, sync_endpoint, endpoint_name):
117+
def _run_error_test(num_procs, sync_endpoint, endpoint_name, v1):
107118
import asyncio
108119

109120
if sync_endpoint:
@@ -112,7 +123,7 @@ def _run_error_test(num_procs, sync_endpoint, endpoint_name):
112123
actor_class = ErrorActor
113124

114125
async def run_test():
115-
proc = proc_mesh(gpus=num_procs)
126+
proc = spawn_procs_on_this_host(v1, per_host={"gpus": num_procs})
116127
error_actor = proc.spawn("error_actor", actor_class)
117128

118129
# This output is checked in the test to make sure that the process actually got here
@@ -145,29 +156,31 @@ def main():
145156
@click.option("--sync-test-impl", type=bool, required=True)
146157
@click.option("--sync-endpoint", type=bool, required=True)
147158
@click.option("--endpoint-name", type=str, required=True)
148-
def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name):
159+
@click.option("--v1", type=bool, required=True)
160+
def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name, v1):
149161
print(
150162
f"Running segfault test: {num_procs=} {sync_test_impl=} {sync_endpoint=}, {endpoint_name=}"
151163
)
152164

153165
if sync_test_impl:
154-
_run_error_test_sync(num_procs, sync_endpoint, endpoint_name)
166+
_run_error_test_sync(num_procs, sync_endpoint, endpoint_name, v1)
155167
else:
156-
_run_error_test(num_procs, sync_endpoint, endpoint_name)
168+
_run_error_test(num_procs, sync_endpoint, endpoint_name, v1)
157169

158170

159171
@main.command("error-bootstrap")
160-
def error_bootstrap():
172+
@click.option("--v1", type=bool, required=True)
173+
def error_bootstrap(v1):
161174
print("Started function error_bootstrap", flush=True)
162-
proc_mesh(
163-
gpus=4, env={"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}
175+
spawn_procs_on_this_host(
176+
v1, {"gpus": 4}, env={"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}
164177
).initialized.get()
165178

166179

167-
async def _error_unmonitored():
180+
async def _error_unmonitored(v1):
168181
print("Started function _error_unmonitored", flush=True)
169182

170-
proc = proc_mesh(gpus=1)
183+
proc = spawn_procs_on_this_host(v1, {"gpus": 1})
171184
actor = proc.spawn("error_actor", ErrorActor)
172185

173186
# fire and forget
@@ -183,11 +196,11 @@ async def _error_unmonitored():
183196

184197
"""
185198
TODO: This test should be enabled when stop() is fully implemented.
186-
async def _error_unmonitored():
199+
async def _error_unmonitored(v1):
187200
print("I actually ran")
188201
sys.stdout.flush()
189202
190-
proc = proc_mesh(gpus=1)
203+
proc = spawn_procs_on_this_host(v1, {"gpus": 1})
191204
actor = proc.spawn("error_actor", ErrorActor)
192205
193206
# fire and forget
@@ -204,16 +217,17 @@ async def _error_unmonitored():
204217

205218

206219
@main.command("error-unmonitored")
207-
def error_unmonitored():
208-
asyncio.run(_error_unmonitored())
220+
@click.option("--v1", type=bool, required=True)
221+
def error_unmonitored(v1):
222+
asyncio.run(_error_unmonitored(v1))
209223

210224

211-
async def _error_cleanup():
225+
async def _error_cleanup(v1):
212226
"""Test function that spawns an 8 process procmesh and calls an endpoint that returns a normal exception."""
213227
print("Started function _error_cleanup() for parent process", flush=True)
214228

215229
# Spawn an 8 process procmesh
216-
proc = proc_mesh(gpus=8)
230+
proc = spawn_procs_on_this_host(v1, {"gpus": 8})
217231
error_actor = proc.spawn("error_actor", ErrorActor)
218232

219233
print("Procmesh spawned, collecting child PIDs from actors", flush=True)
@@ -239,9 +253,10 @@ async def _error_cleanup():
239253

240254

241255
@main.command("error-cleanup")
242-
def error_cleanup():
256+
@click.option("--v1", type=bool, required=True)
257+
def error_cleanup(v1):
243258
"""Command that spawns an 8 process procmesh and calls an endpoint that returns a normal exception."""
244-
asyncio.run(_error_cleanup())
259+
asyncio.run(_error_cleanup(v1))
245260

246261

247262
if __name__ == "__main__":

0 commit comments

Comments
 (0)