Skip to content

Commit 11e76b4

Browse files
committed
Add SessionRouter and LeastLoadedRouter, with tests
1 parent 4a5ba51 commit 11e76b4

File tree

5 files changed

+291
-37
lines changed

5 files changed

+291
-37
lines changed

src/forge/controller/service/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,22 @@
66

77
from .interface import ServiceInterface, Session, SessionContext
88
from .metrics import ServiceMetrics
9-
from .replica import Replica, ReplicaMetrics
9+
from .replica import Replica, ReplicaMetrics, ReplicaState
10+
from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter
1011
from .service import Service, ServiceActor, ServiceConfig
1112

1213
__all__ = [
1314
"Replica",
1415
"ReplicaMetrics",
16+
"ReplicaState",
1517
"Service",
1618
"ServiceConfig",
1719
"ServiceInterface",
1820
"ServiceMetrics",
1921
"Session",
2022
"SessionContext",
2123
"ServiceActor",
24+
"LeastLoadedRouter",
25+
"RoundRobinRouter",
26+
"SessionRouter",
2227
]

src/forge/controller/service/interface.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import logging
1515
from abc import ABC, abstractmethod
1616
from dataclasses import dataclass
17-
from typing import Generic, List, ParamSpec, TypeVar
17+
from typing import Dict, Generic, List, ParamSpec, TypeVar
1818

1919
from monarch._src.actor.endpoint import EndpointProperty
2020

@@ -287,7 +287,10 @@ class Router(ABC):
287287

288288
@abstractmethod
def get_replica(
    self,
    replicas: List[Replica],
    sess_id: str | None = None,
    session_map: Dict[str, int] | None = None,
) -> Replica:
    """Pick one replica out of ``replicas`` according to this router's policy.

    Args:
        replicas: Candidate replicas to choose from.
        sess_id: Session identifier of the request, or ``None`` for a
            request that is not tied to a session.
        session_map: Optional mapping from session id to replica index.
            Session-aware routers read and update it; other routers may
            ignore it.

    Returns:
        The replica that should serve the request.
    """
    ...

src/forge/controller/service/router.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import List
7+
import logging
8+
from typing import Dict, List
89

910
from .interface import Router
1011
from .replica import Replica
1112

13+
# Module-level logger; SessionRouter uses it to report session assignments.
logger = logging.getLogger(__name__)
14+
# NOTE(review): setting DEBUG on this logger at import time means it no longer
# inherits its level from the application's logging configuration. Confirm
# this is intentional; library modules normally leave the level unset.
logger.setLevel(logging.DEBUG)
15+
1216

1317
class RoundRobinRouter(Router):
1418
"""Round-robin router for stateless requests."""
@@ -17,7 +21,10 @@ def __init__(self):
1721
self._next_idx = 0
1822

1923
def get_replica(
20-
self, replicas: List[Replica], sess_id: str | None = None
24+
self,
25+
replicas: List[Replica],
26+
sess_id: str | None = None,
27+
session_map: Dict[str, int] | None = None,
2128
) -> Replica:
2229
healthy_replicas = [r for r in replicas if r.healthy]
2330
if not healthy_replicas:
@@ -27,3 +34,57 @@ def get_replica(
2734
replica = healthy_replicas[self._next_idx]
2835

2936
return replica
37+
38+
39+
class LeastLoadedRouter(Router):
    """Router that selects the healthy replica with the smallest load."""

    def get_replica(
        self,
        replicas: List["Replica"],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> "Replica":
        """Return the healthy replica whose ``current_load`` is lowest.

        Args:
            replicas: Candidate replicas; only those with a truthy
                ``healthy`` attribute are considered.
            sess_id: Accepted for interface compatibility; not used.
            session_map: Accepted for interface compatibility; not used.

        Returns:
            The healthy replica with the minimum ``current_load``. On a tie
            the one appearing first in ``replicas`` wins.

        Raises:
            RuntimeError: If none of the replicas is healthy.
        """
        candidates = [candidate for candidate in replicas if candidate.healthy]
        if candidates:
            return min(candidates, key=lambda candidate: candidate.current_load)
        raise RuntimeError("No healthy replicas available for session assignment")
52+
53+
54+
class SessionRouter(Router):
    """Sticky-session router that delegates new assignments to a fallback.

    A session keeps hitting the replica recorded for it in the session map
    for as long as that replica is present and healthy. Otherwise the
    fallback router chooses a replica and the map is updated.
    """

    def __init__(self, fallback_router: Router):
        """Remember the router used when a session has no usable replica."""
        self.fallback_router = fallback_router

    def get_replica(
        self,
        replicas: List["Replica"],
        sess_id: str | None = None,
        session_map: Dict[str, int] | None = None,
    ) -> "Replica":
        """Return the replica bound to ``sess_id``, binding one if needed.

        Args:
            replicas: Candidate replicas.
            sess_id: Session identifier; required.
            session_map: Mapping from session id to replica index; required.
                It is mutated in place: stale entries are removed and new
                assignments are recorded.

        Returns:
            The healthy replica already mapped to the session, or the
            replica newly chosen by the fallback router.

        Raises:
            ValueError: If ``sess_id`` or ``session_map`` is ``None``.
        """
        if sess_id is None:
            raise ValueError("SessionRouter requires a session ID")
        if session_map is None:
            raise ValueError("Session map must be provided for SessionRouter")

        if sess_id in session_map:
            pinned_idx = session_map[sess_id]
            # First replica that carries the recorded index and is healthy.
            pinned = next(
                (
                    candidate
                    for candidate in replicas
                    if candidate.idx == pinned_idx and candidate.healthy
                ),
                None,
            )
            if pinned is not None:
                return pinned
            # The recorded replica is missing or unhealthy: drop the stale
            # entry before asking the fallback router for a replacement.
            del session_map[sess_id]

        chosen = self.fallback_router.get_replica(replicas, sess_id, session_map)
        session_map[sess_id] = chosen.idx
        logger.debug("Assigning session %s to replica %d", sess_id, chosen.idx)
        return chosen

src/forge/controller/service/service.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,20 @@
3636
import logging
3737
import pprint
3838
import uuid
39-
from typing import Dict, List
39+
from typing import Dict, List, Type
4040

4141
from monarch.actor import Actor, endpoint
4242

43-
from forge.controller.service.interface import _session_context, Session
43+
from forge.controller.service.interface import _session_context, Router, Session
4444

4545
from forge.controller.service.metrics import ServiceMetrics
4646
from forge.controller.service.replica import Replica, ServiceRequest
4747

48-
from forge.controller.service.router import RoundRobinRouter
48+
from forge.controller.service.router import (
49+
LeastLoadedRouter,
50+
RoundRobinRouter,
51+
SessionRouter,
52+
)
4953
from forge.types import ServiceConfig
5054

5155
logger = logging.getLogger(__name__)
@@ -64,6 +68,13 @@ class Service:
6468
actor_def: Actor class definition to instantiate on each replica
6569
*actor_args: Positional arguments passed to actor constructor
6670
**actor_kwargs: Keyword arguments passed to actor constructor
71+
router_cls (Type[Router], optional): Router class used for non-session
72+
calls. Defaults to RoundRobinRouter. Examples include RoundRobinRouter
73+
or LeastLoadedRouter. The router is instantiated internally.
74+
fallback_router_cls: Router class used as a fallback when a session
75+
cannot be mapped to an existing replica. Defaults
76+
to LeastLoadedRouter.
77+
6778
6879
Attributes:
6980
_cfg: Service configuration
@@ -73,16 +84,24 @@ class Service:
7384
_endpoints: Dynamically registered actor endpoints
7485
"""
7586

76-
def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
87+
def __init__(
88+
self,
89+
cfg: ServiceConfig,
90+
actor_def,
91+
actor_kwargs: dict,
92+
router_cls: Type["Router"] = RoundRobinRouter,
93+
fallback_router_cls: Type["Router"] = LeastLoadedRouter,
94+
):
7795
self._cfg = cfg
7896
self._replicas = []
7997
self._actor_def = actor_def
8098
self._actor_kwargs = actor_kwargs
99+
self.router_cls = router_cls
100+
self.fallback_router_cls = fallback_router_cls
81101

82102
self._active_sessions = []
83103
self._id_session_map = {}
84104
self._session_replica_map: Dict[str, int] = {}
85-
self._router = RoundRobinRouter()
86105

87106
# Initialize metrics collection
88107
self._metrics = ServiceMetrics()
@@ -95,6 +114,12 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
95114
async def __initialize__(self):
96115
"""Initializes the service and starts the health loop."""
97116
logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.")
117+
118+
# Initialize the routers
119+
self._default_router = self.router_cls()
120+
self._session_router = SessionRouter(fallback_router=self.fallback_router_cls())
121+
122+
# Initialize all replicas
98123
replicas = []
99124
num_replicas = self._cfg.num_replicas
100125
for i in range(num_replicas):
@@ -457,36 +482,15 @@ async def _health_loop(self, poll_rate_s: float):
457482

458483
await asyncio.sleep(poll_rate_s)
459484

460-
def _get_least_loaded_replica(self) -> "Replica":
461-
"""Get the replica with the lowest load."""
462-
healthy_replicas = [r for r in self._replicas if r.healthy]
463-
if not healthy_replicas:
464-
raise RuntimeError("No healthy replicas available for session assignment")
465-
466-
# Use the replica's current_load property
467-
return min(healthy_replicas, key=lambda replica: replica.current_load)
468-
469485
async def _get_replica(self, sess_id: str | None) -> "Replica":
    """Resolve the replica that should handle a request.

    Requests carrying a session id go through the session router together
    with the service's session-to-replica map; requests without one go
    through the default router.
    """
    if sess_id is not None:
        # Session-bound request: sticky routing backed by the shared map.
        return self._session_router.get_replica(
            self._replicas, sess_id, self._session_replica_map
        )

    # No session, use the default router
    return self._default_router.get_replica(self._replicas)
490494

491495
async def stop(self):
492496
logger.debug("Stopping service...")

0 commit comments

Comments
 (0)