@@ -159,20 +159,19 @@ def __init__(
159159 if self ._scheme == "https" :
160160 self ._ssl_context = ssl .create_default_context ()
161161
162- # HTTP components
163- self ._pool : ConnectionPool | None = None
164- self ._http_template : HttpRequestTemplate | None = None
165- self ._loop : asyncio .AbstractEventLoop | None = None
162+ # HTTP components (initialized in run())
163+ self ._pool : ConnectionPool = None # type: ignore[assignment]
164+ self ._http_template : HttpRequestTemplate = None # type: ignore[assignment]
165+ self ._loop : asyncio .AbstractEventLoop = None # type: ignore[assignment]
166166
167- # IPC transports
168- self ._requests : ReceiverTransport | None = None
169- self ._responses : SenderTransport | None = None
167+ # IPC transports (initialized in run())
168+ self ._requests : ReceiverTransport = None # type: ignore[assignment]
169+ self ._responses : SenderTransport = None # type: ignore[assignment]
170170
171171 # Track active request tasks
172172 self ._active_tasks : set [asyncio .Task ] = set ()
173173
174174 # Use adapter type from config
175- assert self .http_config .adapter is not None
176175 self ._adapter : type [HttpRequestAdapter ] = self .http_config .adapter
177176
178177 async def run (self ) -> None :
@@ -184,7 +183,6 @@ async def run(self) -> None:
184183 # Use eager task factory for immediate coroutine execution
185184 # Tasks start executing synchronously until first await
186185 # NOTE(vir): CRITICAL for minimizing TFB/TTFT
187- assert self ._loop is not None
188186 self ._loop .set_task_factory (asyncio .eager_task_factory ) # type: ignore[arg-type]
189187
190188 # Initialize HTTP template from URL components
@@ -267,7 +265,9 @@ async def run(self) -> None:
267265 if self .http_config .record_worker_events :
268266 pid = os .getpid ()
269267 worker_db_name = f"worker_report_{ self .worker_id } _{ pid } "
270- assert self .http_config .event_logs_dir is not None
268+ assert (
269+ self .http_config .event_logs_dir is not None
270+ ), "event_logs_dir must be set if record_worker_events is enabled"
271271 report_path = self .http_config .event_logs_dir / f"{ worker_db_name } .csv"
272272
273273 with EventRecorder (session_id = worker_db_name ) as event_recorder :
@@ -327,16 +327,13 @@ async def _run_main_loop(self) -> None:
327327 assert_active = True ,
328328 )
329329
330- # Prepare request
331- prepared = self ._prepare_request (query )
332-
333- # Fire request
334- if not await self ._fire_request (prepared ):
330+ # Prepare and fire request
331+ req = self ._prepare_request (query )
332+ if not await self ._fire_request (req ):
335333 continue
336334
337335 # Process response asynchronously
338- assert self ._loop is not None
339- task = self ._loop .create_task (self ._process_response (prepared ))
336+ task = self ._loop .create_task (self ._process_response (req ))
340337
341338 # Keep task alive to prevent GC
342339 # Cleaned up in _process_response finally block
@@ -359,7 +356,6 @@ def _prepare_request(self, query: Query) -> InFlightRequest:
359356 is_streaming = query .data .get ("stream" , False )
360357
361358 # Build complete HTTP request bytes
362- assert self ._http_template is not None
363359 http_bytes = self ._http_template .build_request (
364360 body_bytes ,
365361 is_streaming ,
@@ -381,23 +377,21 @@ async def _fire_request(self, req: InFlightRequest) -> bool:
381377 Fire HTTP POST request:
382378 1. Acquire TCP connection from pool
383379 2. Send POST request bytes
384- 3. Store connection for process_response task
385380
386- Returns True on success.
381+ Returns True on success, False on failure (error response sent) .
387382 """
388383 if self ._shutdown :
389384 await self ._handle_error (req .query_id , "Worker is shutting down" )
390385 return False
391386
392387 try :
393388 # Acquire connection from pool
394- assert self ._pool is not None
395389 conn = await self ._pool .acquire ()
396390
397391 # Write request bytes directly to transport
398392 conn .protocol .write (req .http_bytes )
399393
400- # Store connection for _process_response to use
394+ # Store connection on req for response processing
401395 req .connection = conn
402396
403397 return True
@@ -410,18 +404,14 @@ async def _fire_request(self, req: InFlightRequest) -> bool:
410404 @profile
411405 async def _process_response (self , req : InFlightRequest ) -> None :
412406 """Process response for a fired request."""
413- try :
414- conn = req .connection
415- assert conn is not None , "Connection should be set by _fire_request"
407+ conn = req .connection
416408
409+ try :
417410 # Await headers and handle error status
418411 status_code , _ = await conn .protocol .read_headers ()
419412 if status_code != 200 :
420413 error_body = await conn .protocol .read_body ()
421- # Release connection early - done with socket I/O
422- assert self ._pool is not None
423414 self ._pool .release (conn )
424- req .connection = None
425415 await self ._handle_error (
426416 req .query_id ,
427417 f"HTTP { status_code } : { error_body .decode ('utf-8' , errors = 'replace' )} " ,
@@ -439,11 +429,8 @@ async def _process_response(self, req: InFlightRequest) -> None:
439429 logger .warning (f"Request { req .query_id } failed: { type (e ).__name__ } : { e } " )
440430
441431 finally :
442- # Release connection back to pool if not already released
443- if req .connection :
444- assert self ._pool is not None
445- self ._pool .release (req .connection )
446- req .connection = None
432+ # Release connection back to pool if not already released
433+ self ._pool .release (conn )
447434
448435 # Record completion event
449436 if self .http_config .record_worker_events :
@@ -462,18 +449,15 @@ async def _process_response(self, req: InFlightRequest) -> None:
462449 @profile
463450 async def _handle_streaming_body (self , req : InFlightRequest ) -> None :
464451 """Handle streaming (SSE) response body."""
465- conn = req .connection
466- assert conn is not None
467452 query_id = req .query_id
453+ conn = req .connection
468454
469455 # Create accumulator for streaming response
470- assert self .http_config .accumulator is not None
471456 accumulator = self .http_config .accumulator (
472457 query_id , self .http_config .stream_all_chunks
473458 )
474459
475460 # Process SSE stream - yields batches of chunks
476- assert self ._responses is not None
477461 async for chunk_batch in self ._iter_sse_lines (conn ):
478462 for delta in chunk_batch :
479463 if stream_chunk := accumulator .add_chunk (delta ):
@@ -487,10 +471,8 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None:
487471 assert_active = True ,
488472 )
489473
490- # Release connection early - done with socket I/O
491- assert self ._pool is not None
474+ # Release connection early - done with socket I/O (idempotent)
492475 self ._pool .release (conn )
493- req .connection = None
494476
495477 # Send final complete back to main rank
496478 self ._responses .send (accumulator .get_final_output ())
@@ -505,23 +487,19 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None:
505487 @profile
506488 async def _handle_non_streaming_body (self , req : InFlightRequest ) -> None :
507489 """Handle non-streaming response body."""
508- conn = req .connection
509- assert conn is not None
510490 query_id = req .query_id
491+ conn = req .connection
511492
512493 # Read entire response body
513494 response_bytes = await conn .protocol .read_body ()
514495
515- # Release connection early - done with socket I/O
516- assert self ._pool is not None
496+ # Release connection early - done with socket I/O (idempotent)
517497 self ._pool .release (conn )
518- req .connection = None
519498
520499 # Decode using adapter
521500 result = self ._adapter .decode_response (response_bytes , query_id )
522501
523502 # Send result back to main rank
524- assert self ._responses is not None
525503 self ._responses .send (result )
526504 if self .http_config .record_worker_events :
527505 EventRecorder .record_event (
0 commit comments