-
Notifications
You must be signed in to change notification settings - Fork 16
Auto-track and globally shut down all Forge actors and services #357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
cc5fb5a
9cc5531
84dcd17
8ba3e5f
5b8d81c
5fe99d5
36225af
d705df9
52fa867
f0ba99a
4e477b5
e7ae73d
eb9f667
77c2344
b39fc4b
d4f3d57
0db9ed6
f31aa3a
ffbf0ca
7c5417d
5dc138c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,12 @@ | |
|
||
from monarch.actor import Actor, current_rank, current_size, endpoint | ||
|
||
from forge.controller.provisioner import get_proc_mesh, stop_proc_mesh | ||
from forge.controller.provisioner import ( | ||
get_proc_mesh, | ||
register_actor, | ||
register_service, | ||
stop_proc_mesh, | ||
) | ||
|
||
from forge.types import ProcessConfig, ServiceConfig | ||
|
||
|
@@ -81,11 +86,11 @@ def options( | |
|
||
# Pre-configure a single actor | ||
actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...) | ||
await actor.shutdown() | ||
await MyForgeActor.shutdown(actor) | ||
|
||
# Default usage without calling options | ||
actor = await MyForgeActor.as_actor(...) | ||
await actor.shutdown() | ||
await MyForgeActor.shutdown(actor) | ||
""" | ||
|
||
attrs = { | ||
|
@@ -127,7 +132,9 @@ async def as_service( | |
logger.info("Spawning Service for %s", cls.__name__) | ||
service = Service(cfg, cls, actor_args, actor_kwargs) | ||
await service.__initialize__() | ||
return ServiceInterface(service, cls) | ||
service_interface = ServiceInterface(service, cls) | ||
await register_service(service_interface) | ||
DNXie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return service_interface | ||
|
||
@endpoint | ||
async def setup(self): | ||
|
@@ -145,7 +152,7 @@ async def setup(self): | |
pass | ||
|
||
@classmethod | ||
async def launch(cls, *args, **kwargs) -> "ForgeActor": | ||
async def launch(cls, *args, **kwargs) -> "ActorMesh": | ||
"""Provisions and deploys a new actor. | ||
|
||
This method is used by `Service` to provision a new replica. | ||
|
@@ -185,13 +192,16 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: | |
""" | ||
logger.info("Spawning single actor %s", cls.__name__) | ||
actor = await cls.launch(*args, **actor_kwargs) | ||
await register_actor(actor) | ||
DNXie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return actor | ||
|
||
@classmethod | ||
async def shutdown(cls, actor: "ForgeActor"): | ||
async def shutdown(cls, actor: "ForgeActor", quiet: bool = False): | ||
"""Shuts down an actor. | ||
This method is used by `Service` to teardown a replica. | ||
""" | ||
if not quiet: | ||
logger.info(f"Shutting down actor {getattr(actor, 'name', cls.__name__)}") | ||
|
||
if actor._proc_mesh is None: | ||
raise AssertionError("Called shutdown on a replica with no proc_mesh.") | ||
await stop_proc_mesh(actor._proc_mesh) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,6 +131,9 @@ def __init__(self, cfg: ProvisionerConfig | None = None): | |
if not self.launcher: | ||
logger.warning("Launcher not provided, remote allocations will not work.") | ||
|
||
self._registered_actors: list["ForgeActor"] = [] | ||
self._registered_services: list["ServiceInterface"] = [] | ||
|
||
async def initialize(self): | ||
"""Call this after creating the instance""" | ||
if self.launcher is not None: | ||
|
@@ -302,8 +305,52 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): | |
commands.kill(server_name) | ||
del self._proc_host_map[proc_mesh] | ||
|
||
def register_service(self, service: "ServiceInterface") -> None:
    """Record a service handle in this provisioner's cleanup registry.

    Args:
        service: The handle to remember; must be a ``ServiceInterface``.

    Raises:
        TypeError: If ``service`` is not a ``ServiceInterface`` instance.
    """
    # Resolved at call time rather than at module import.
    from forge.controller.service import ServiceInterface

    if isinstance(service, ServiceInterface):
        self._registered_services.append(service)
        return

    raise TypeError(
        f"register_service expected ServiceInterface, got {type(service)}"
    )
|
||
def register_actor(self, actor: "ForgeActor") -> None:
    """Record a single-actor handle in this provisioner's cleanup registry.

    Args:
        actor: The handle to remember; it is validated against monarch's
            ``ActorMesh`` type, not against ``ForgeActor`` itself.

    Raises:
        TypeError: If ``actor`` is not an ``ActorMesh`` instance.
    """
    # Resolved at call time rather than at module import.
    from monarch._src.actor.actor_mesh import ActorMesh

    if isinstance(actor, ActorMesh):
        self._registered_actors.append(actor)
        return

    raise TypeError(f"register_actor expected ActorMesh, got {type(actor)}")
|
||
async def shutdown_all_allocations(self):
    """Shut down every registered service and actor, then empty both registries.

    Services are stopped first, then actors; each registry is walked from the
    most recently registered entry to the oldest. A failure on one entry is
    logged as a warning and does not stop the remaining entries from being
    processed.
    """
    # Services: newest registration first.
    for svc in reversed(self._registered_services):
        try:
            await svc.shutdown()
        except Exception as err:
            logger.warning(f"Failed to shut down {svc}: {err}")

    # Actors: newest registration first.
    for handle in reversed(self._registered_actors):
        try:
            # Prefer the class recorded on the handle as `_class`; fall back
            # to the handle's own type when that attribute is missing/falsy.
            owner = getattr(handle, "_class", None)
            if not owner:
                owner = handle.__class__
            await owner.shutdown(handle)
        except Exception as err:
            logger.warning(f"Failed to shut down {handle}: {err}")

    # Both registries are emptied even if some entries failed above.
    self._registered_actors.clear()
    self._registered_services.clear()
|
||
async def shutdown(self): | ||
"""Tears down all remaining remote allocations.""" | ||
await self.shutdown_all_allocations() | ||
async with self._lock: | ||
for server_name in self._server_names: | ||
commands.kill(server_name) | ||
|
@@ -372,6 +419,18 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): | |
return await provisioner.host_mesh_from_proc(proc_mesh) | ||
|
||
|
||
async def register_service(service: "ServiceInterface") -> None:
    """Forward ``service`` to the global provisioner's ``register_service``.

    Args:
        service: The service handle to register.
    """
    (await _get_provisioner()).register_service(service)
|
||
|
||
async def register_actor(actor: "ForgeActor") -> None:
    """Forward ``actor`` to the global provisioner's ``register_actor``.

    Args:
        actor: The actor handle to register.
    """
    (await _get_provisioner()).register_actor(actor)
|
||
|
||
async def stop_proc_mesh(proc_mesh: ProcMesh):
    """Ask the global provisioner to stop ``proc_mesh``.

    Args:
        proc_mesh: The proc mesh to stop.

    Returns:
        Whatever the provisioner's ``stop_proc_mesh`` returns.
    """
    global_provisioner = await _get_provisioner()
    outcome = await global_provisioner.stop_proc_mesh(proc_mesh=proc_mesh)
    return outcome
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -200,6 +200,7 @@ async def shutdown(self) -> None: | |
""" | ||
Shut down the underlying Service. | ||
""" | ||
logger.info(f"Shutting down service {self.actor_def.__name__}") | ||
|
||
await self._service.stop() | ||
|
||
def session(self) -> "SessionContext": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DNXie @felipemello1 maybe we can just move the mlogger shutdown into the global shutdown as well?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved it into `shutdown()`.