|
32 | 32 | ... result = await service.my_endpoint(arg1, arg2) |
33 | 33 | """ |
34 | 34 |
|
35 | | - |
36 | 35 | import asyncio |
37 | 36 | import contextvars |
38 | 37 | import logging |
@@ -96,43 +95,54 @@ def get_avg_capacity_utilization(self, replicas: List) -> float: |
96 | 95 | healthy_replicas = [r for r in replicas if r.healthy] |
97 | 96 | if not healthy_replicas: |
98 | 97 | return 0.0 |
99 | | - |
100 | 98 | total_utilization = sum(r.capacity_utilization for r in healthy_replicas) |
101 | 99 | return total_utilization / len(healthy_replicas) |
102 | 100 |
|
103 | 101 | def get_sessions_per_replica(self) -> float: |
104 | | - """Get average sessions per healthy replica.""" |
105 | | - if self.healthy_replicas == 0: |
| 102 | + """Get average sessions per replica.""" |
| 103 | + if self.total_replicas == 0: |
106 | 104 | return 0.0 |
107 | | - return self.total_sessions / self.healthy_replicas |
| 105 | + return self.total_sessions / self.total_replicas |
| 106 | + |
| 107 | + |
| 108 | +# Context variable for session state; None when no session is active
| 109 | +_session_context = contextvars.ContextVar("session_context", default=None)
108 | 110 |
|
109 | 111 |
|
110 | 112 | @dataclass |
111 | 113 | class Session: |
| 114 | + """Simple session data holder.""" |
| 115 | + |
112 | 116 | session_id: str |
113 | 117 |
|
114 | 118 |
|
115 | | -# Global context variable for session state |
116 | | -# This is used to propagate session state across async tasks |
117 | | -_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar( |
118 | | - "session_context", default=None |
119 | | -) |
| 119 | +class SessionContext: |
| 120 | + """ |
| 121 | + Async context manager for stateful service sessions with automatic lifecycle management. |
| 122 | +
|
| 123 | + Provides a convenient way to maintain stateful connections to replicas across multiple |
| 124 | + requests. Sessions ensure that all requests within the context are routed to the same |
| 125 | + replica, enabling stateful interactions while handling session lifecycle automatically. |
| 126 | +
|
| 127 | + Example: |
120 | 128 |
|
| 129 | + >>> async with service.session() as session: |
| 130 | + ... # All calls within this block use the same replica |
| 131 | + ... result1 = await service.my_endpoint(arg1) |
| 132 | + ... result2 = await service.another_endpoint(result1) |
121 | 133 |
|
122 | | -class SessionContext: |
123 | | - """Context manager for service sessions using context variables.""" |
| 134 | + """ |
124 | 135 |
|
125 | | - def __init__(self, service: "Service", **session_kwargs): |
| 136 | + def __init__(self, service: "Service"): |
126 | 137 | self.service = service |
127 | 138 | self.session_id: str | None = None |
128 | | - self.session_kwargs = session_kwargs |
129 | 139 | self._token = None |
130 | 140 |
|
131 | 141 | async def __aenter__(self): |
132 | 142 | """Start a session and set context variables.""" |
133 | 143 | self.session_id = await self.service.start_session() |
134 | 144 | # Set context for this async task |
135 | | - context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs} |
| 145 | + context_value = {"session_id": self.session_id} |
136 | 146 | self._token = _session_context.set(context_value) |
137 | 147 | return self |
138 | 148 |
|
@@ -228,8 +238,8 @@ async def __initialize__(self): |
228 | 238 | num_replicas = self._cfg.num_replicas |
229 | 239 | for i in range(num_replicas): |
230 | 240 | replica = Replica( |
231 | | - proc_config=self._cfg.to_process_config(), |
232 | 241 | idx=len(self._replicas) + i, |
| 242 | + proc_config=self._cfg.to_process_config(), |
233 | 243 | max_concurrent_requests=self._cfg.replica_max_concurrent_requests, |
234 | 244 | return_first_rank_result=self._cfg.return_first_rank_result, |
235 | 245 | ) |
@@ -294,13 +304,8 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): |
294 | 304 | ctx = _session_context.get() |
295 | 305 | if ctx: |
296 | 306 | sess_id = ctx["session_id"] |
297 | | - routing_hints = ctx["kwargs"] |
298 | | - else: |
299 | | - routing_hints = {} |
300 | | - else: |
301 | | - routing_hints = {} |
302 | 307 |
|
303 | | - replica = await self._get_replica(sess_id, **routing_hints) |
| 308 | + replica = await self._get_replica(sess_id) |
304 | 309 |
|
305 | 310 | # Create a ServiceRequest object to queue |
306 | 311 | request = ServiceRequest( |
@@ -412,9 +417,9 @@ async def start_session(self) -> str: |
412 | 417 |
|
413 | 418 | return sess_id |
414 | 419 |
|
415 | | - def session(self, **kwargs) -> SessionContext: |
| 420 | + def session(self) -> SessionContext: |
416 | 421 | """Returns a context manager for session-based calls.""" |
417 | | - return SessionContext(self, **kwargs) |
| 422 | + return SessionContext(self) |
418 | 423 |
|
419 | 424 | def _update_service_metrics(self): |
420 | 425 | """Updates service-level metrics.""" |
@@ -582,8 +587,8 @@ def get_load(replica: "Replica") -> int: |
582 | 587 |
|
583 | 588 | return min(healthy_replicas, key=get_load) |
584 | 589 |
|
585 | | - async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": |
586 | | - """Get a replica for the given session ID, with optional custom routing hints.""" |
| 590 | + async def _get_replica(self, sess_id: str | None) -> "Replica": |
| 591 | + """Get a replica for the given session ID.""" |
587 | 592 | if sess_id is None: |
588 | 593 | # No session, use round-robin load balancing |
589 | 594 | replica = self._get_next_replica() |
|
0 commit comments