3636import  logging 
3737import  pprint 
3838import  uuid 
39- from  typing  import  Dict , List 
39+ from  typing  import  Dict , List ,  Type 
4040
4141from  monarch .actor  import  Actor , endpoint 
4242
43- from  forge .controller .service .interface  import  _session_context , Session 
43+ from  forge .controller .service .interface  import  _session_context , Router ,  Session 
4444
4545from  forge .controller .service .metrics  import  ServiceMetrics 
4646from  forge .controller .service .replica  import  Replica , ServiceRequest 
4747
48- from  forge .controller .service .router  import  RoundRobinRouter 
48+ from  forge .controller .service .router  import  (
49+     LeastLoadedRouter ,
50+     RoundRobinRouter ,
51+     SessionRouter ,
52+ )
4953from  forge .types  import  ServiceConfig 
5054
5155logger  =  logging .getLogger (__name__ )
@@ -64,6 +68,13 @@ class Service:
6468        actor_def: Actor class definition to instantiate on each replica 
6569        *actor_args: Positional arguments passed to actor constructor 
6670        **actor_kwargs: Keyword arguments passed to actor constructor 
71+         router_cls (Type[Router], optional): Router class used for calls made 
72+             without a session. Defaults to RoundRobinRouter; LeastLoadedRouter is 
73+             another option. The service instantiates the router itself. 
74+         fallback_router_cls (Type[Router], optional): Router class used as a 
75+             fallback when a session cannot be mapped to an existing replica. 
76+             Defaults to LeastLoadedRouter. The service instantiates it itself. 
77+ 
6778
6879    Attributes: 
6980        _cfg: Service configuration 
@@ -73,16 +84,24 @@ class Service:
7384        _endpoints: Dynamically registered actor endpoints 
7485    """ 
7586
76-     def  __init__ (self , cfg : ServiceConfig , actor_def , actor_kwargs : dict ):
87+     def  __init__ (
88+         self ,
89+         cfg : ServiceConfig ,
90+         actor_def ,
91+         actor_kwargs : dict ,
92+         router_cls : Type ["Router" ] =  RoundRobinRouter ,
93+         fallback_router_cls : Type ["Router" ] =  LeastLoadedRouter ,
94+     ):
7795        self ._cfg  =  cfg 
7896        self ._replicas  =  []
7997        self ._actor_def  =  actor_def 
8098        self ._actor_kwargs  =  actor_kwargs 
99+         self .router_cls  =  router_cls 
100+         self .fallback_router_cls  =  fallback_router_cls 
81101
82102        self ._active_sessions  =  []
83103        self ._id_session_map  =  {}
84104        self ._session_replica_map : Dict [str , int ] =  {}
85-         self ._router  =  RoundRobinRouter ()
86105
87106        # Initialize metrics collection 
88107        self ._metrics  =  ServiceMetrics ()
@@ -95,6 +114,12 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
95114    async  def  __initialize__ (self ):
96115        """Initializes the service and starts the health loop.""" 
97116        logger .debug (f"Starting service up with { self ._cfg .num_replicas }   replicas." )
117+ 
118+         # Initialize the routers 
119+         self ._default_router  =  self .router_cls ()
120+         self ._session_router  =  SessionRouter (fallback_router = self .fallback_router_cls ())
121+ 
122+         # Initialize all replicas 
98123        replicas  =  []
99124        num_replicas  =  self ._cfg .num_replicas 
100125        for  i  in  range (num_replicas ):
@@ -457,36 +482,15 @@ async def _health_loop(self, poll_rate_s: float):
457482
458483            await  asyncio .sleep (poll_rate_s )
459484
460-     def  _get_least_loaded_replica (self ) ->  "Replica" :
461-         """Get the replica with the lowest load.""" 
462-         healthy_replicas  =  [r  for  r  in  self ._replicas  if  r .healthy ]
463-         if  not  healthy_replicas :
464-             raise  RuntimeError ("No healthy replicas available for session assignment" )
465- 
466-         # Use the replica's current_load property 
467-         return  min (healthy_replicas , key = lambda  replica : replica .current_load )
468- 
469485    async  def  _get_replica (self , sess_id : str  |  None ) ->  "Replica" :
470486        """Get a replica for the given session ID.""" 
471487        if  sess_id  is  None :
472488            # No session, use the default router 
473-             return  self ._router .get_replica (self ._replicas )
474- 
475-         # Session-based routing 
476-         if  sess_id  in  self ._session_replica_map :
477-             replica_idx  =  self ._session_replica_map [sess_id ]
478-             # Find the replica with this index 
479-             for  replica  in  self ._replicas :
480-                 if  replica .idx  ==  replica_idx  and  replica .healthy :
481-                     return  replica 
482-             # If the replica is no longer healthy, remove from session map and reassign 
483-             del  self ._session_replica_map [sess_id ]
489+             return  self ._default_router .get_replica (self ._replicas )
484490
485-         # New session, assign to least loaded replica 
486-         replica  =  self ._get_least_loaded_replica ()
487-         self ._session_replica_map [sess_id ] =  replica .idx 
488-         logger .debug ("Assigning session %s to replica %d" , sess_id , replica .idx )
489-         return  replica 
491+         return  self ._session_router .get_replica (
492+             self ._replicas , sess_id , self ._session_replica_map 
493+         )
490494
491495    async  def  stop (self ):
492496        logger .debug ("Stopping service..." )
0 commit comments