-
Notifications
You must be signed in to change notification settings - Fork 16
Auto-track and globally shut down all Forge actors and services #357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
cc5fb5a
9cc5531
84dcd17
8ba3e5f
5b8d81c
5fe99d5
36225af
d705df9
52fa867
f0ba99a
4e477b5
e7ae73d
eb9f667
77c2344
b39fc4b
d4f3d57
0db9ed6
f31aa3a
ffbf0ca
7c5417d
5dc138c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,8 @@ | |
import socket | ||
import uuid | ||
|
||
from monarch._src.actor.actor_mesh import ActorMesh | ||
|
||
from monarch._src.actor.shape import Extent, NDSlice, Shape | ||
from monarch.actor import Actor, endpoint, ProcMesh | ||
|
||
|
@@ -141,6 +143,9 @@ def __init__(self, cfg: ProvisionerConfig | None = None): | |
if not self.launcher: | ||
logger.warning("Launcher not provided, remote allocations will not work.") | ||
|
||
self._registered_actors: list["ForgeActor"] = [] | ||
self._registered_services: list["ServiceInterface"] = [] | ||
|
||
async def initialize(self): | ||
"""Call this after creating the instance""" | ||
if self.launcher is not None: | ||
|
@@ -359,8 +364,55 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): | |
commands.kill(server_name) | ||
del self._proc_host_map[proc_mesh] | ||
|
||
def register_service(self, service: "ServiceInterface") -> None:
    """Track a service allocation so it can be torn down at shutdown.

    Args:
        service: the service to register for cleanup.

    Raises:
        TypeError: if *service* is not a ``ServiceInterface`` instance.
    """
    # Imported lazily: a top-level import of ServiceInterface would create a
    # circular dependency between this module and forge.controller.service.
    from forge.controller.service import ServiceInterface

    if isinstance(service, ServiceInterface):
        self._registered_services.append(service)
    else:
        raise TypeError(
            f"register_service expected ServiceInterface, got {type(service)}"
        )
|
||
def register_actor(self, actor: "ForgeActor") -> None:
    """Track a single actor allocation so it can be torn down at shutdown.

    Args:
        actor: the actor to register for cleanup.

    Raises:
        TypeError: if *actor* is not an ``ActorMesh`` instance.
    """
    if isinstance(actor, ActorMesh):
        self._registered_actors.append(actor)
    else:
        raise TypeError(f"register_actor expected ActorMesh, got {type(actor)}")
|
||
async def shutdown_all_allocations(self):
    """Gracefully shut down every tracked service and actor, then clear
    both registries.

    Allocations are shut down in reverse registration order. Failures are
    logged and skipped (best-effort) so that one bad allocation cannot
    block the teardown of the others.
    """
    # Lazy %-style args: the logging module formats the message only if the
    # record is actually emitted (avoids eager f-string evaluation).
    logger.info(
        "Shutting down %d service(s) and %d actor(s)...",
        len(self._registered_services),
        len(self._registered_actors),
    )
    # --- ServiceInterface ---
    for service in reversed(self._registered_services):
        try:
            await service.shutdown()
        except Exception as e:
            logger.warning("Failed to shut down %s: %s", service, e)

    # --- Actor instance (ForgeActor or underlying ActorMesh) ---
    for actor in reversed(self._registered_actors):
        try:
            # Resolve the class whose `shutdown` should be called (ForgeActor
            # or its bound class) and invoke it unbound with the instance.
            actor_cls = getattr(actor, "_class", None) or actor.__class__
            await actor_cls.shutdown(actor)
        except Exception as e:
            logger.warning("Failed to shut down %s: %s", actor, e)

    self._registered_actors.clear()
    self._registered_services.clear()
||
async def shutdown(self):
    """Tear down all remaining remote allocations.

    Drains the tracked actor/service registries first, then kills every
    launched server while holding the provisioner lock.
    """
    await self.shutdown_all_allocations()
    async with self._lock:
        for name in self._server_names:
            commands.kill(name)
|
@@ -429,12 +481,43 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): | |
return await provisioner.host_mesh_from_proc(proc_mesh) | ||
|
||
|
||
async def register_service(service: "ServiceInterface") -> None:
    """Record *service* with the global provisioner for shutdown tracking."""
    (await _get_provisioner()).register_service(service)
|
||
|
||
async def register_actor(actor: "ForgeActor") -> None:
    """Record *actor* with the global provisioner for shutdown tracking."""
    (await _get_provisioner()).register_actor(actor)
|
||
|
||
async def stop_proc_mesh(proc_mesh: ProcMesh):
    """Stop *proc_mesh* through the global provisioner."""
    manager = await _get_provisioner()
    return await manager.stop_proc_mesh(proc_mesh=proc_mesh)
|
||
|
||
async def shutdown_metric_logger():
    """Shut down the global metric logger together with all its backends.

    Best-effort: a failure is logged as a warning rather than raised, so the
    surrounding shutdown sequence can continue.
    """
    # Inline import: importing at module level would introduce a circular
    # dependency with forge.observability.metric_actors.
    from forge.observability.metric_actors import get_or_create_metric_logger

    logger.info("Shutting down metric logger...")
    try:
        metric_logger = await get_or_create_metric_logger()
        await metric_logger.shutdown.call_one()
    except Exception as e:
        logger.warning(f"Failed to shutdown metric logger: {e}")
|
||
|
||
async def shutdown():
    """Shut down the metric logger first, then the global provisioner.

    Returns whatever the provisioner's ``shutdown()`` returns.
    """
    await shutdown_metric_logger()

    logger.info("Shutting down provisioner..")
    manager = await _get_provisioner()
    outcome = await manager.shutdown()
    logger.info("Shutdown completed successfully")
    return outcome
Uh oh!
There was an error while loading. Please reload this page.