Skip to content

Commit 7d6b247

Browse files
author
Allen Wang
committed
more minor cleanups
1 parent efe1806 commit 7d6b247

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

src/forge/controller/service.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
... result = await service.my_endpoint(arg1, arg2)
3333
"""
3434

35-
3635
import asyncio
3736
import contextvars
3837
import logging
@@ -96,43 +95,54 @@ def get_avg_capacity_utilization(self, replicas: List) -> float:
9695
healthy_replicas = [r for r in replicas if r.healthy]
9796
if not healthy_replicas:
9897
return 0.0
99-
10098
total_utilization = sum(r.capacity_utilization for r in healthy_replicas)
10199
return total_utilization / len(healthy_replicas)
102100

103101
def get_sessions_per_replica(self) -> float:
104-
"""Get average sessions per healthy replica."""
105-
if self.healthy_replicas == 0:
102+
"""Get average sessions per replica."""
103+
if self.total_replicas == 0:
106104
return 0.0
107-
return self.total_sessions / self.healthy_replicas
105+
return self.total_sessions / self.total_replicas
106+
107+
108+
# Context variable for session state
109+
_session_context = contextvars.ContextVar("session_context")
108110

109111

110112
@dataclass
111113
class Session:
114+
"""Simple session data holder."""
115+
112116
session_id: str
113117

114118

115-
# Global context variable for session state
116-
# This is used to propagate session state across async tasks
117-
_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar(
118-
"session_context", default=None
119-
)
119+
class SessionContext:
120+
"""
121+
Async context manager for stateful service sessions with automatic lifecycle management.
122+
123+
Provides a convenient way to maintain stateful connections to replicas across multiple
124+
requests. Sessions ensure that all requests within the context are routed to the same
125+
replica, enabling stateful interactions while handling session lifecycle automatically.
126+
127+
Example:
120128
129+
>>> async with service.session() as session:
130+
... # All calls within this block use the same replica
131+
... result1 = await service.my_endpoint(arg1)
132+
... result2 = await service.another_endpoint(result1)
121133
122-
class SessionContext:
123-
"""Context manager for service sessions using context variables."""
134+
"""
124135

125-
def __init__(self, service: "Service", **session_kwargs):
136+
def __init__(self, service: "Service"):
126137
self.service = service
127138
self.session_id: str | None = None
128-
self.session_kwargs = session_kwargs
129139
self._token = None
130140

131141
async def __aenter__(self):
132142
"""Start a session and set context variables."""
133143
self.session_id = await self.service.start_session()
134144
# Set context for this async task
135-
context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs}
145+
context_value = {"session_id": self.session_id}
136146
self._token = _session_context.set(context_value)
137147
return self
138148

@@ -228,8 +238,8 @@ async def __initialize__(self):
228238
num_replicas = self._cfg.num_replicas
229239
for i in range(num_replicas):
230240
replica = Replica(
231-
proc_config=self._cfg.to_process_config(),
232241
idx=len(self._replicas) + i,
242+
proc_config=self._cfg.to_process_config(),
233243
max_concurrent_requests=self._cfg.replica_max_concurrent_requests,
234244
return_first_rank_result=self._cfg.return_first_rank_result,
235245
)
@@ -294,13 +304,8 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
294304
ctx = _session_context.get()
295305
if ctx:
296306
sess_id = ctx["session_id"]
297-
routing_hints = ctx["kwargs"]
298-
else:
299-
routing_hints = {}
300-
else:
301-
routing_hints = {}
302307

303-
replica = await self._get_replica(sess_id, **routing_hints)
308+
replica = await self._get_replica(sess_id)
304309

305310
# Create a ServiceRequest object to queue
306311
request = ServiceRequest(
@@ -412,9 +417,9 @@ async def start_session(self) -> str:
412417

413418
return sess_id
414419

415-
def session(self, **kwargs) -> SessionContext:
420+
def session(self) -> SessionContext:
416421
"""Returns a context manager for session-based calls."""
417-
return SessionContext(self, **kwargs)
422+
return SessionContext(self)
418423

419424
def _update_service_metrics(self):
420425
"""Updates service-level metrics."""
@@ -582,8 +587,8 @@ def get_load(replica: "Replica") -> int:
582587

583588
return min(healthy_replicas, key=get_load)
584589

585-
async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica":
586-
"""Get a replica for the given session ID, with optional custom routing hints."""
590+
async def _get_replica(self, sess_id: str | None) -> "Replica":
591+
"""Get a replica for the given session ID."""
587592
if sess_id is None:
588593
# No session, use round-robin load balancing
589594
replica = self._get_next_replica()

0 commit comments

Comments
 (0)