class Router(ABC):
    """Interface for strategies that choose which replica serves a request.

    Subclasses must override :meth:`get_replica`; because the method is
    abstract, ``Router`` itself cannot be instantiated.
    """

    @abstractmethod
    def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> Replica:
        """Return the replica selected from ``healthy_replicas``.

        Args:
            healthy_replicas: Candidate replicas to choose from.
            sess_id: Optional session identifier of the request.
            session_map: Optional mapping of session ID to replica index.

        Returns:
            The replica chosen by the concrete routing strategy.
        """
class RoundRobinRouter(Router):
    """Router that cycles through the healthy replicas, one per request."""

    def __init__(self):
        # Index handed out by the previous call. It is advanced *before* each
        # pick, so the first request on a list of two or more replicas
        # returns index 1.
        self._next_idx = 0

    def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> Replica:
        """Return the next replica in rotation.

        ``sess_id`` and ``session_map`` are accepted for interface
        compatibility and are not read.

        Raises:
            RuntimeError: If ``healthy_replicas`` is empty.
        """
        if not healthy_replicas:
            raise RuntimeError("No healthy replicas available for load balancing")

        pool_size = len(healthy_replicas)
        # The modulo keeps the cursor in range even when the pool has shrunk
        # since the previous call.
        self._next_idx = (self._next_idx + 1) % pool_size
        return healthy_replicas[self._next_idx]


class LeastLoadedRouter(Router):
    """Router that picks the replica reporting the smallest ``current_load``."""

    def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> Replica:
        """Return the replica with the lowest ``current_load``.

        When several replicas tie, the earliest one in the list is returned.
        ``sess_id`` and ``session_map`` are accepted for interface
        compatibility and are not read.

        Raises:
            RuntimeError: If ``healthy_replicas`` is empty.
        """
        if not healthy_replicas:
            raise RuntimeError("No healthy replicas available for session assignment")

        def _load_of(replica):
            return replica.current_load

        return min(healthy_replicas, key=_load_of)


class SessionRouter(Router):
    """Sticky-session router that delegates new assignments to a fallback."""

    def __init__(self, fallback_router: Router):
        # Router consulted whenever a session has no usable assignment.
        self.fallback_router = fallback_router

    def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> Replica:
        """Return the replica bound to ``sess_id``, assigning one if needed.

        ``session_map`` is mutated in place: a binding whose replica is not
        in ``healthy_replicas`` is removed, and the replica chosen by the
        fallback router is recorded under ``sess_id``.

        Raises:
            ValueError: If ``sess_id`` or ``session_map`` is ``None``.
        """
        if sess_id is None:
            raise ValueError("SessionRouter requires a session ID")
        if session_map is None:
            raise ValueError("Session map must be provided for SessionRouter")

        if sess_id in session_map:
            pinned_idx = session_map[sess_id]
            sticky = next(
                (r for r in healthy_replicas if r.idx == pinned_idx),
                None,
            )
            if sticky is not None:
                return sticky
            # The pinned replica is not among the healthy ones: drop the
            # stale binding and fall through to a fresh assignment.
            del session_map[sess_id]

        chosen = self.fallback_router.get_replica(
            healthy_replicas, sess_id, session_map
        )
        session_map[sess_id] = chosen.idx
        logger.debug(
            "Assigning session %s to replica %d",
            sess_id,
            chosen.idx,
        )
        return chosen
ServiceMetrics() @@ -93,6 +104,12 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): async def __initialize__(self): """Initializes the service and starts the health loop.""" logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.") + + # Initialize the routers + self._default_router = RoundRobinRouter() + self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) + + # Initialize all replicas replicas = [] num_replicas = self._cfg.num_replicas for i in range(num_replicas): @@ -455,47 +472,16 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - def _get_next_replica(self) -> "Replica": - """Get the next replica using round-robin selection.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if not healthy_replicas: - raise RuntimeError("No healthy replicas available for load balancing") - - # Simple round-robin - self._next_replica_idx = (self._next_replica_idx + 1) % len(healthy_replicas) - return healthy_replicas[self._next_replica_idx] - - def _get_least_loaded_replica(self) -> "Replica": - """Get the replica with the lowest load.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if not healthy_replicas: - raise RuntimeError("No healthy replicas available for session assignment") - - # Use the replica's current_load property - return min(healthy_replicas, key=lambda replica: replica.current_load) - async def _get_replica(self, sess_id: str | None) -> "Replica": """Get a replica for the given session ID.""" + healthy_replicas = [r for r in self._replicas if r.healthy] if sess_id is None: - # No session, use round-robin load balancing - replica = self._get_next_replica() - return replica + # No session, use the default router + return self._default_router.get_replica(healthy_replicas) - # Session-based routing - if sess_id in self._session_replica_map: - replica_idx = self._session_replica_map[sess_id] - # Find the replica with this index - 
for replica in self._replicas: - if replica.idx == replica_idx and replica.healthy: - return replica - # If the replica is no longer healthy, remove from session map and reassign - del self._session_replica_map[sess_id] - - # New session, assign to least loaded replica - replica = self._get_least_loaded_replica() - self._session_replica_map[sess_id] = replica.idx - logger.debug("Assigning session %s to replica %d", sess_id, replica.idx) - return replica + return self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) async def stop(self): logger.debug("Stopping service...") @@ -592,7 +578,6 @@ async def _get_internal_state(self) -> dict: for replica in self._replicas ], # Load balancing state - "next_replica_idx": self._next_replica_idx, # Service-level state "total_replicas": len(self._replicas), "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index fb8504ed2..ee3f39eb0 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -13,8 +13,15 @@ import pytest from forge.controller import ForgeActor - -from forge.controller.service import ServiceConfig +from forge.controller.service import ( + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + ServiceConfig, + SessionRouter, +) +from forge.types import ProcessConfig from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) @@ -56,6 +63,19 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: return self.v +def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: + """Helper to build a replica with specified state and load.""" + replica = Replica( + idx=idx, + proc_config=ProcessConfig(), + actor_def=Counter, + actor_kwargs={}, + ) + replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY + replica.active_requests = load + return replica + + # Core 
# Router Tests


@pytest.mark.asyncio
async def test_session_router_with_round_robin_fallback():
    """SessionRouter places new sessions using whichever fallback it wraps."""
    # Round-robin fallback: two new sessions land on different replicas,
    # regardless of the load each replica reports.
    replicas = [make_replica(0, load=0), make_replica(1, load=5)]
    session_map = {}
    router = SessionRouter(RoundRobinRouter())

    r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
    r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map)

    assert r1.idx != r2.idx
    assert set(session_map.values()) == {0, 1}

    # Least-loaded fallback: both new sessions land on the same replica.
    replicas = [make_replica(0, load=0), make_replica(1, load=5)]
    session_map = {}
    router = SessionRouter(LeastLoadedRouter())

    r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
    r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map)

    assert r1.idx == r2.idx == 0
    assert session_map == {"sess1": 0, "sess2": 0}


# Router integration tests


@pytest.mark.timeout(10)
@pytest.mark.asyncio
async def test_round_robin_router_distribution():
    """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas."""
    service = await Counter.options(procs_per_replica=1, num_replicas=3).as_service(v=0)

    try:
        # Each sessionless choose() increments exactly one replica, so six
        # calls over three replicas should leave two increments on each.
        for _ in range(6):
            await service.incr.choose()

        final_values = await service.value.call()
        assert sorted(final_values) == [2, 2, 2]

    finally:
        await service.shutdown()


@pytest.mark.timeout(10)
@pytest.mark.asyncio
async def test_session_router_assigns_and_updates_session_map_in_service():
    """Integration: Service with SessionRouter preserves sticky sessions."""
    # Calls that carry a sess_id are routed by the service's SessionRouter;
    # sessionless calls would use the round-robin default router instead.
    service = await Counter.options(
        procs_per_replica=1,
        num_replicas=2,
    ).as_service(v=0)

    try:
        # First call with sess_id -> assign a replica
        await service.incr.choose(sess_id="sess1")
        values1 = await service.value.call()

        # Second call with same sess_id -> must hit same replica
        await service.incr.choose(sess_id="sess1")
        values2 = await service.value.call()

        # Difference should only be on one replica (sticky session)
        diffs = [v2 - v1 for v1, v2 in zip(values1, values2)]
        assert (
            sum(diffs) == 1
        ), f"Expected exactly one replica to increment, got {diffs}"
        assert max(diffs) == 1 and min(diffs) == 0

        # Session map in service should reflect assigned replica
        assigned_idx = service._session_replica_map["sess1"]
        assert values2[assigned_idx] == values1[assigned_idx] + 1

    finally:
        await service.shutdown()