Skip to content

Commit 7c40432

Browse files
shuhuayumeta-codesync[bot]
authored andcommitted
Enable get_or_spawn_controller from inside an actor endpoint (#1347)
Summary: Pull Request resolved: #1347 This diff enables get_or_spawn_controller API being called from inside an actor endpoint, and adds unit tests for this functionality. ghstack-source-id: 313701938 exported-using-ghexport Reviewed By: samlurye Differential Revision: D83109145 fbshipit-source-id: 19a8084982060877f827bd73ee8b39adacd5af77
1 parent 0d75ddd commit 7c40432

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

python/monarch/_src/actor/proc_mesh.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,12 @@ async def task() -> HyProcMesh:
343343
)._class
344344
is _ControllerController
345345
), "Expected v0 _ControllerController, got v1 _ControllerController"
346-
pm._controller_controller = instance._controller_controller # type: ignore
346+
if instance._controller_controller is None:
347+
pm._controller_controller = _get_controller_controller()[1]
348+
else:
349+
pm._controller_controller = cast(
350+
_ControllerController, instance._controller_controller
351+
)
347352
instance._add_child(pm)
348353

349354
async def task(

python/tests/test_python_actors.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@
4848
fake_in_process_host,
4949
HostMesh,
5050
)
51-
from monarch._src.actor.proc_mesh import _get_bootstrap_args, ProcMesh
51+
from monarch._src.actor.proc_mesh import (
52+
_get_bootstrap_args,
53+
get_or_spawn_controller,
54+
ProcMesh,
55+
)
5256
from monarch._src.actor.v1.host_mesh import (
5357
_bootstrap_cmd,
5458
fake_in_process_host as fake_in_process_host_v1,
@@ -1640,3 +1644,25 @@ def test_cuda_is_not_initialized_in_a_new_proc():
16401644
pytest.skip("cannot find cuda")
16411645
proc = this_host().spawn_procs().spawn("is_init", IsInit)
16421646
assert not proc.is_cuda_initialized.call_one().get()
1647+
1648+
1649+
class SpawningActorFromEndpointActor(Actor):
1650+
def __init__(self, root="None"):
1651+
self._root = root
1652+
1653+
@endpoint
1654+
def return_root(self):
1655+
return self._root
1656+
1657+
@endpoint
1658+
async def spawning_from_endpoint(self, name, root) -> None:
1659+
await get_or_spawn_controller(name, SpawningActorFromEndpointActor, root=root)
1660+
1661+
1662+
@pytest.mark.timeout(60)
1663+
def test_get_or_spawn_controller_inside_actor_endpoint():
1664+
actor_1 = get_or_spawn_controller("actor_1", SpawningActorFromEndpointActor).get()
1665+
actor_1.spawning_from_endpoint.call_one("actor_2", root="actor_1").get()
1666+
actor_2 = get_or_spawn_controller("actor_2", SpawningActorFromEndpointActor).get()
1667+
# verify that actor_2 was spawned from actor_1 with the correct root
1668+
assert actor_2.return_root.call_one().get() == "actor_1"

0 commit comments

Comments
 (0)