4 changes: 4 additions & 0 deletions src/forge/controller/__init__.py
@@ -4,13 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
from .interface import ServiceInterface, Session, SessionContext
from .proc_mesh import get_proc_mesh, spawn_actors
from .service import Service, ServiceConfig
from .spawn import spawn_service

__all__ = [
"Service",
"ServiceConfig",
"ServiceInterface",
"Session",
"SessionContext",
"spawn_service",
"spawn_actors",
"get_proc_mesh",
184 changes: 184 additions & 0 deletions src/forge/controller/interface.py
@@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Service interface and session management.

This module provides the user-facing API for interacting with distributed services,
including session management, context propagation, and dynamic endpoint registration.
"""

import contextvars
import logging
from dataclasses import dataclass
from typing import Generic, List, ParamSpec, TypeVar

from monarch._src.actor.endpoint import EndpointProperty

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

P = ParamSpec("P")
R = TypeVar("R")


@dataclass
class Session:
"""Simple session data holder."""

session_id: str


# Context variable for session state
_session_context = contextvars.ContextVar("session_context")


class SessionContext:
"""
Async context manager for stateful service sessions with automatic lifecycle management.

Provides a convenient way to maintain stateful connections to replicas across multiple
requests. Sessions ensure that all requests within the context are routed to the same
replica, enabling stateful interactions while handling session lifecycle automatically.

Example:

>>> async with service.session() as session:
... # All calls within this block use the same replica
... result1 = await service.my_endpoint.choose(arg1)
... result2 = await service.another_endpoint.choose(result1)

"""

def __init__(self, service: "ServiceInterface"):
self.service = service
self.session_id: str | None = None
self._token = None

async def __aenter__(self):
"""Start a session and set context variables."""
self.session_id = await self.service.start_session()
# Set context for this async task
context_value = {"session_id": self.session_id}
self._token = _session_context.set(context_value)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Terminate the session and restore context."""
if self._token:
_session_context.reset(self._token)
if self.session_id:
await self.service.terminate_session(self.session_id)
self.session_id = None


class ServiceEndpoint(Generic[P, R]):
"""An endpoint object specific to services.

This loosely mimics the Endpoint APIs exposed in Monarch, with
a few key differences:
- Only choose and call are retained (dropping stream and call_one)
- Call returns a list directly rather than a ValueMesh.

These changes are made with Forge use cases in mind, but can
certainly be expanded/adapted in the future.

"""

def __init__(self, actor_mesh, endpoint_name: str):
self.actor_mesh = actor_mesh
self.endpoint_name = endpoint_name

async def choose(
self, *args: P.args, sess_id: str | None = None, **kwargs: P.kwargs
) -> R:
"""Chooses a replica to call based on context and load balancing strategy."""
return await self.actor_mesh._call.call_one(
sess_id, self.endpoint_name, *args, **kwargs
)

async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
"""Broadcasts a request to all healthy replicas and returns the results as a list."""
result = await self.actor_mesh._call_all.call_one(
self.endpoint_name, *args, **kwargs
)
return result


class ServiceInterface:
Member:
Why is this needed?

Contributor Author:
added a comment in the docstring, but a few reasons:

  • primarily so you don't have to interact with the Service actor, which IMO can be annoying
  • pairs the proc_mesh with the Service actor
  • later we might want to pass a reference to the service to other actors, without moving where it's placed. This doubles as a handle so we can make calls to it

"""
A lightweight interface to a Service Actor running on a single-node mesh.

This interface holds references to the proc_mesh and actor_mesh (both of size 1)
and exposes its user-defined actor endpoints as ServiceEndpoint objects that
route through the Service Actor's _call and _call_all endpoints.

The ServiceInterface acts as the handle that is returned to end clients,
providing a simple interface that makes actual calls to the Service Actor.
"""

def __init__(self, _proc_mesh, _service, actor_def):
self._proc_mesh = _proc_mesh
self._service = _service
self.actor_def = actor_def

# Dynamically create ServiceEndpoint objects for user's actor endpoints
# Inspect the actor_def directly to find endpoints
for attr_name in dir(actor_def):
attr_value = getattr(actor_def, attr_name)
if isinstance(attr_value, EndpointProperty):
# Create a ServiceEndpoint that will route through the Service Actor
endpoint = ServiceEndpoint(self._service, attr_name)
setattr(self, attr_name, endpoint)

# Session management methods - handled by ServiceInterface
async def start_session(self) -> str:
"""Starts a new session for stateful request handling."""
return await self._service.start_session.call_one()

async def terminate_session(self, sess_id: str):
"""Terminates an active session and cleans up associated resources."""
return await self._service.terminate_session.call_one(sess_id)

def session(self) -> "SessionContext":
"""Returns a context manager for session-based calls."""
return SessionContext(self)

# Service control methods - forwarded to Service Actor
async def stop(self):
"""Stops the service gracefully."""
# First stop the service
await self._service.stop.call_one()
# Then stop its underlying proc
await self._proc_mesh.stop()

# Metrics methods - forwarded to Service Actor
async def get_metrics(self):
"""Get comprehensive service metrics for monitoring and analysis."""
return await self._service.get_metrics.call_one()

async def get_metrics_summary(self):
"""Get a summary of key metrics for monitoring and debugging."""
return await self._service.get_metrics_summary.call_one()

# Testing method - forwarded to Service Actor
def _get_internal_state(self):
"""
Get comprehensive internal state for testing purposes.

Returns:
dict: Complete internal state including sessions, replicas, and metrics
"""
return self._service._get_internal_state.call_one()

def __getattr__(self, name: str):
"""Forward all other attribute access to the underlying Service Actor."""
# Forward everything else to the _service
if hasattr(self._service, name):
return getattr(self._service, name)

raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
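
Taken together, the interface gives clients a small surface: spawn a service, hit its actor endpoints through the generated ServiceEndpoint attributes, and wrap related calls in session() when they must land on the same replica. A minimal usage sketch follows; the Counter actor, the ServiceConfig field names, and the exact spawn_service call shape are assumptions for illustration and are not part of this diff.

# Hypothetical end-to-end usage of the ServiceInterface above. `Counter` is a
# made-up ForgeActor with a single `incr` endpoint; the ServiceConfig fields
# and the spawn_service signature are assumed for this sketch only.
import asyncio

from forge.controller import ServiceConfig, spawn_service
from my_actors import Counter  # hypothetical actor definition


async def main():
    # spawn_service returns a ServiceInterface wrapping the Service actor.
    service = await spawn_service(
        ServiceConfig(num_replicas=2, procs_per_replica=1),  # assumed fields
        Counter,
    )

    # Each actor endpoint shows up as a ServiceEndpoint attribute:
    # choose() routes to one replica, call() broadcasts to all healthy ones.
    one = await service.incr.choose()
    every = await service.incr.call()
    print(one, every)

    # Pin a sequence of calls to the same replica via an explicit session id.
    async with service.session() as sess:
        await service.incr.choose(sess_id=sess.session_id)
        await service.incr.choose(sess_id=sess.session_id)

    print(await service.get_metrics_summary())
    await service.stop()


asyncio.run(main())
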
73 changes: 73 additions & 0 deletions src/forge/controller/metrics.py
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Service metrics collection and aggregation.

This module provides comprehensive metrics tracking for distributed services,
including per-replica performance data, service-wide aggregations, and
health status information.
"""

from dataclasses import dataclass, field
from typing import Dict, List

from forge.controller.replica import ReplicaMetrics


# TODO - tie this into metrics logger when it exists.
@dataclass
class ServiceMetrics:
"""
Aggregated metrics collection for the entire service.

Provides service-wide visibility into performance, health, and scaling metrics
by aggregating data from all replica instances.

Attributes:
replica_metrics: Per-replica metrics indexed by replica ID
total_sessions: Number of active sessions across all replicas
healthy_replicas: Number of currently healthy replicas
total_replicas: Total number of replicas (healthy + unhealthy)
last_scale_event: Timestamp of the last scaling operation
"""

# Replica metrics
replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict)
# Service-level metrics
total_sessions: int = 0
healthy_replicas: int = 0
total_replicas: int = 0
# Time-based metrics
last_scale_event: float = 0.0

def get_total_request_rate(self, window_seconds: float = 60.0) -> float:
"""Get total requests per second across all replicas."""
return sum(
metrics.get_request_rate(window_seconds)
for metrics in self.replica_metrics.values()
)

def get_avg_queue_depth(self, replicas: List) -> float:
"""Get average queue depth across all healthy replicas."""
healthy_replicas = [r for r in replicas if r.healthy]
if not healthy_replicas:
return 0.0
total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas)
return total_queue_depth / len(healthy_replicas)

def get_avg_capacity_utilization(self, replicas: List) -> float:
"""Get average capacity utilization across all healthy replicas."""
healthy_replicas = [r for r in replicas if r.healthy]
if not healthy_replicas:
return 0.0
total_utilization = sum(r.capacity_utilization for r in healthy_replicas)
return total_utilization / len(healthy_replicas)

def get_sessions_per_replica(self) -> float:
"""Get average sessions per replica."""
if self.total_replicas == 0:
return 0.0
return self.total_sessions / self.total_replicas
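
To see how the aggregation helpers behave, here is a small self-contained sketch that feeds ServiceMetrics stand-in replica objects. FakeReplica is not part of this diff; it only mirrors the attributes that get_avg_queue_depth and get_avg_capacity_utilization read from real replicas (healthy, request_queue, capacity_utilization).

# Illustrative only: FakeReplica is a minimal stand-in for a replica object.
import asyncio
from dataclasses import dataclass, field

from forge.controller.metrics import ServiceMetrics


@dataclass
class FakeReplica:
    healthy: bool
    capacity_utilization: float
    request_queue: asyncio.Queue = field(default_factory=asyncio.Queue)


replicas = [FakeReplica(True, 0.5), FakeReplica(True, 0.9), FakeReplica(False, 1.0)]
replicas[0].request_queue.put_nowait("pending request")  # one queued request

metrics = ServiceMetrics(total_sessions=6, healthy_replicas=2, total_replicas=3)
print(metrics.get_avg_queue_depth(replicas))           # (1 + 0) / 2 = 0.5
print(metrics.get_avg_capacity_utilization(replicas))  # (0.5 + 0.9) / 2 = 0.7
print(metrics.get_sessions_per_replica())              # 6 / 3 = 2.0
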
43 changes: 19 additions & 24 deletions src/forge/controller/replica.py
@@ -199,22 +199,23 @@ async def _do_recovery():
if old_proc_mesh is not None:
try:
await old_proc_mesh.stop()
logger.debug("Old proc_mesh stopped for replica %d", self.idx)
logger.debug(f"Old proc_mesh stopped for replica {self.idx}")
except Exception as e:
logger.warning(
"Error stopping old proc_mesh for replica %d: %s", self.idx, e
f"Error stopping old proc_mesh for replica {self.idx}: {e}"
)

try:
logger.debug("Creating new proc_mesh for replica %d", self.idx)
logger.debug(f"Creating new proc_mesh for replica {self.idx}")
await self.initialize()
logger.debug("Recovery completed successfully for replica %d", self.idx)
logger.debug(f"Recovery completed successfully for replica {self.idx}")
except Exception as e:
logger.error("Recovery failed for replica %d: %s", self.idx, e)
logger.error(f"Recovery failed for replica {self.idx}: {e}")
self.state = ReplicaState.UNHEALTHY
raise

logger.debug("Starting recovery for replica %d", self.idx)
logger.debug(f"Starting recovery for replica {self.idx}")
self.state = ReplicaState.RECOVERING
self._recovery_task = asyncio.create_task(_do_recovery())
await self._recovery_task

@@ -223,15 +224,15 @@ async def create_proc_mesh(self):
# TODO - for policy replica, we would override this method to
# include multiple proc_meshes
if self.proc_mesh is not None:
logger.warning("Proc mesh already initialized for replica %d", self.idx)
logger.warning(f"Proc mesh already initialized for replica {self.idx}")
return

logger.debug("Creating proc_mesh for replica %d", self.idx)
logger.debug(f"Creating proc_mesh for replica {self.idx}")
try:
self.proc_mesh = await get_proc_mesh(process_config=self.proc_config)
logger.debug("Proc mesh created successfully for replica %d", self.idx)
logger.debug(f"Proc mesh created successfully for replica {self.idx}")
except Exception as e:
logger.error("Failed to create proc_mesh for replica %d: %s", self.idx, e)
logger.error(f"Failed to create proc_mesh for replica {self.idx}: {e}")
self.state = ReplicaState.UNHEALTHY
raise

@@ -262,10 +263,10 @@ async def spawn_actor(self, actor_def, *actor_args, **actor_kwargs):
if setup_method := getattr(self.actor, "setup", None):
await setup_method.call()

logger.debug("Actor spawned successfully on replica %d", self.idx)
logger.debug(f"Actor spawned successfully on replica {self.idx}")

except Exception as e:
logger.error("Failed to spawn actor on replica %d: %s", self.idx, e)
logger.error(f"Failed to spawn actor on replica {self.idx}: {e}")
self.mark_failed()
raise

@@ -275,7 +276,7 @@ def start_processing(self):
"""Start the replica's processing loop if not already running."""
if self._run_task is None or self._run_task.done():
self._run_task = asyncio.create_task(self.run())
logger.debug("Started processing loop for replica %d", self.idx)
logger.debug(f"Started processing loop for replica {self.idx}")

async def enqueue_request(self, request: ServiceRequest):
"""Enqueues a request for processing by this replica."""
@@ -317,7 +318,7 @@ async def _process_single_request(self, request: ServiceRequest) -> bool:
result = result._values[0]
request.future.set_result(result)
except ActorError as e:
logger.warning("Got failure on replica %d. Error:\n%s", self.idx, e)
logger.warning(f"Got failure on replica {self.idx}. Error:\n{e}")
# The exception came from the actor. It itself is
# returned to be propagated through the services
# back to the caller.
@@ -329,9 +330,7 @@ async def _process_single_request(self, request: ServiceRequest) -> bool:
self.mark_failed()
success = False
except Exception as e:
logger.debug(
"Got unexpected error on replica %d. Error:\n%s", self.idx, e
)
logger.debug(f"Got unexpected error on replica {self.idx}. Error:\n{e}")
self.mark_failed()

# The exception was not from the actor - in this case
@@ -377,17 +376,13 @@ async def run(self):
continue

except Exception as e:
logger.error(
"Error in replica %d processing loop: %s",
self.idx,
e,
)
logger.error(f"Error in replica {self.idx} processing loop: {e}")
self.state = ReplicaState.UNHEALTHY
break

finally:
self._running = False
logger.debug("Replica %d stopped processing", self.idx)
logger.debug(f"Replica {self.idx} stopped processing")

# Replica state management

@@ -418,7 +413,7 @@ def failed(self) -> bool:

def mark_failed(self):
"""Mark the replica as failed, triggering recovery."""
logger.debug("Marking replica %d as failed", self.idx)
logger.debug(f"Marking replica {self.idx} as failed")
self.state = ReplicaState.RECOVERING

async def stop(self):