Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/forge/controller/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@

from .interface import ServiceInterface, Session, SessionContext
from .metrics import ServiceMetrics
from .replica import Replica, ReplicaMetrics
from .replica import Replica, ReplicaMetrics, ReplicaState
from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter
from .service import Service, ServiceActor, ServiceConfig

__all__ = [
"Replica",
"ReplicaMetrics",
"ReplicaState",
"Service",
"ServiceConfig",
"ServiceInterface",
"ServiceMetrics",
"Session",
"SessionContext",
"ServiceActor",
"LeastLoadedRouter",
"RoundRobinRouter",
"SessionRouter",
]
19 changes: 18 additions & 1 deletion src/forge/controller/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@

import contextvars
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, List, ParamSpec, TypeVar
from typing import Dict, Generic, List, ParamSpec, TypeVar

from monarch._src.actor.endpoint import EndpointProperty

from .replica import Replica

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -277,3 +280,17 @@ def __getattr__(self, name: str):
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)


class Router(ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - I think this Router can actually just be in router.py

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can keep it in interface.py since that's where all the interfaces are.

"""Abstract base class for routing logic."""

@abstractmethod
def get_replica(
    self,
    healthy_replicas: List[Replica],
    sess_id: str | None = None,
    session_map: Dict[str, int] | None = None,
) -> Replica:
    """Select a replica from the list based on routing logic.

    Concrete routers must override this method.

    Args:
        healthy_replicas: Candidate replicas to choose from.
        sess_id: Optional session identifier; defaults to ``None``.
        session_map: Optional mapping from session ID to an integer
            replica index; defaults to ``None``.

    Returns:
        The replica chosen by the router.
    """
90 changes: 90 additions & 0 deletions src/forge/controller/service/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, List

from .interface import Router
from .replica import Replica

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class RoundRobinRouter(Router):
    """Router that cycles through the healthy replicas in order.

    Intended for stateless requests: the session arguments are accepted
    only to satisfy the ``Router`` interface and are never read.
    """

    def __init__(self):
        # Position of the most recently returned replica. It is advanced
        # *before* each pick, so the very first pick lands on position 1
        # whenever more than one replica is available.
        self._next_idx = 0

    def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> Replica:
        """Return the next replica in rotation.

        Args:
            healthy_replicas: Candidate replicas to rotate over.
            sess_id: Unused.
            session_map: Unused.

        Returns:
            The replica at the advanced rotation position.

        Raises:
            RuntimeError: If ``healthy_replicas`` is empty.
        """
        if not healthy_replicas:
            raise RuntimeError("No healthy replicas available for load balancing")

        # Wrap with the current list length so the cursor stays valid even
        # if the number of healthy replicas changed since the last call.
        advanced = (self._next_idx + 1) % len(healthy_replicas)
        self._next_idx = advanced
        return healthy_replicas[advanced]


class LeastLoadedRouter(Router):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: BalancedRouter?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think LeastConnectedRouter would be the most canonically accurate

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

See Allen's discussion on this in another thread.

I think LeastConnectedRouter would be the most canonically accurate

I agree. So let's keep it as LeastConnectedRouter for now.

"""Always routes to the replica with the lowest current load."""

def get_replica(
    self,
    healthy_replicas: List[Replica],
    sess_id: str | None = None,
    session_map: Dict[str, int] | None = None,
) -> Replica:
    """Return the replica whose ``current_load`` is smallest.

    Args:
        healthy_replicas: Candidate replicas to compare.
        sess_id: Unused.
        session_map: Unused.

    Returns:
        The candidate with the minimum ``current_load``; on ties the
        earliest one in ``healthy_replicas`` wins.

    Raises:
        RuntimeError: If ``healthy_replicas`` is empty.
    """
    if healthy_replicas:
        return min(healthy_replicas, key=lambda candidate: candidate.current_load)
    raise RuntimeError("No healthy replicas available for session assignment")


class SessionRouter(Router):
"""Session-based routing: sticky sessions with a fallback router."""

def __init__(self, fallback_router: Router):
    """Store the router used when a session has no usable assignment.

    Args:
        fallback_router: Router consulted to pick a replica for sessions
            that are not (or no longer) mapped to a healthy replica.
    """
    self.fallback_router = fallback_router

def get_replica(
self,
healthy_replicas: List[Replica],
sess_id: str | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sess_id: str,

i.e. don't assume None is passable for this or session_map. I would also get rid of the checks below

Copy link
Member Author

@DNXie DNXie Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because I have an interface for routers: (also see interface.py)

class Router(ABC):
    """Abstract base class for routing logic."""

    @abstractmethod
    def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> Replica:
        """Select a replica from the list based on routing logic."""
        pass

For RRRouter and LeastLoadedRouter, this could be None.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm ok, I'm not sure how I feel about that longer term but that's ok for now

session_map: Dict[str, int] | None = None,
) -> Replica:
if sess_id is None:
raise ValueError("SessionRouter requires a session ID")

if session_map is None:
raise ValueError("Session map must be provided for SessionRouter")

# Check if session already has a replica
if sess_id in session_map:
replica_idx = session_map[sess_id]
# Find the replica with this index
for r in healthy_replicas:
if r.idx == replica_idx:
return r
# If the replica is no longer healthy, remove from session map and reassign
del session_map[sess_id]

# Use fallback router to assign a new replica
replica = self.fallback_router.get_replica(
healthy_replicas, sess_id, session_map
)
session_map[sess_id] = replica.idx
logger.debug(
"Assigning session %s to replica %d",
sess_id,
replica.idx,
)
return replica
65 changes: 25 additions & 40 deletions src/forge/controller/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@

from forge.controller.service.metrics import ServiceMetrics
from forge.controller.service.replica import Replica, ServiceRequest

from forge.controller.service.router import (
LeastLoadedRouter,
RoundRobinRouter,
SessionRouter,
)
from forge.types import ServiceConfig

logger = logging.getLogger(__name__)
Expand All @@ -63,6 +69,7 @@ class Service:
*actor_args: Positional arguments passed to actor constructor
**actor_kwargs: Keyword arguments passed to actor constructor


Attributes:
_cfg: Service configuration
_replicas: List of managed replica instances
Expand All @@ -71,7 +78,12 @@ class Service:
_endpoints: Dynamically registered actor endpoints
"""

def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
def __init__(
self,
cfg: ServiceConfig,
actor_def,
actor_kwargs: dict,
):
self._cfg = cfg
self._replicas = []
self._actor_def = actor_def
Expand All @@ -80,7 +92,6 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
self._active_sessions = []
self._id_session_map = {}
self._session_replica_map: Dict[str, int] = {}
self._next_replica_idx = 0 # For round-robin load balancing

# Initialize metrics collection
self._metrics = ServiceMetrics()
Expand All @@ -93,6 +104,12 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
async def __initialize__(self):
"""Initializes the service and starts the health loop."""
logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.")

# Initialize the routers
self._default_router = RoundRobinRouter()
self._session_router = SessionRouter(fallback_router=LeastLoadedRouter())

# Initialize all replicas
replicas = []
num_replicas = self._cfg.num_replicas
for i in range(num_replicas):
Expand Down Expand Up @@ -455,47 +472,16 @@ async def _health_loop(self, poll_rate_s: float):

await asyncio.sleep(poll_rate_s)

def _get_next_replica(self) -> "Replica":
"""Get the next replica using round-robin selection."""
healthy_replicas = [r for r in self._replicas if r.healthy]
if not healthy_replicas:
raise RuntimeError("No healthy replicas available for load balancing")

# Simple round-robin
self._next_replica_idx = (self._next_replica_idx + 1) % len(healthy_replicas)
return healthy_replicas[self._next_replica_idx]

def _get_least_loaded_replica(self) -> "Replica":
"""Get the replica with the lowest load."""
healthy_replicas = [r for r in self._replicas if r.healthy]
if not healthy_replicas:
raise RuntimeError("No healthy replicas available for session assignment")

# Use the replica's current_load property
return min(healthy_replicas, key=lambda replica: replica.current_load)

async def _get_replica(self, sess_id: str | None) -> "Replica":
"""Get a replica for the given session ID."""
healthy_replicas = [r for r in self._replicas if r.healthy]
if sess_id is None:
# No session, use round-robin load balancing
replica = self._get_next_replica()
return replica
# No session, use the default router
return self._default_router.get_replica(healthy_replicas)

# Session-based routing
if sess_id in self._session_replica_map:
replica_idx = self._session_replica_map[sess_id]
# Find the replica with this index
for replica in self._replicas:
if replica.idx == replica_idx and replica.healthy:
return replica
# If the replica is no longer healthy, remove from session map and reassign
del self._session_replica_map[sess_id]

# New session, assign to least loaded replica
replica = self._get_least_loaded_replica()
self._session_replica_map[sess_id] = replica.idx
logger.debug("Assigning session %s to replica %d", sess_id, replica.idx)
return replica
return self._session_router.get_replica(
healthy_replicas, sess_id, self._session_replica_map
)

async def stop(self):
logger.debug("Stopping service...")
Expand Down Expand Up @@ -592,7 +578,6 @@ async def _get_internal_state(self) -> dict:
for replica in self._replicas
],
# Load balancing state
"next_replica_idx": self._next_replica_idx,
# Service-level state
"total_replicas": len(self._replicas),
"healthy_replica_count": sum(1 for r in self._replicas if r.healthy),
Expand Down
Loading
Loading