Skip to content

Commit 0ceafb4

Browse files
authored
chore(http-client): cleanup types, improve coverage, remove orjson (#159)
* drop orjson
* update
* trigger CI
1 parent 760fb88 commit 0ceafb4

File tree

17 files changed

+300
-153
lines changed

17 files changed

+300
-153
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ dependencies = [
3737
"websocket-client==1.9.0",
3838
# Data handling
3939
"duckdb==1.4.0",
40-
"orjson==3.11.5",
4140
"msgspec==0.20.0",
4241
"pydantic==2.12.0",
4342
"pydantic_core==2.41.1",

src/inference_endpoint/endpoint_client/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,13 @@
1515

1616
"""
1717
Endpoint Client for the MLPerf Inference Endpoint Benchmarking System.
18-
19-
This module provides HTTP client implementation with multiprocessing and ZMQ.
18+
This module provides HTTP client implementation.
2019
"""
2120

2221
from .config import HTTPClientConfig
23-
from .http_client import AsyncHttpEndpointClient, HTTPEndpointClient
22+
from .http_client import HTTPEndpointClient
2423

2524
__all__ = [
26-
"AsyncHttpEndpointClient",
2725
"HTTPEndpointClient",
2826
"HTTPClientConfig",
2927
]

src/inference_endpoint/endpoint_client/config.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -117,13 +117,16 @@ class HTTPClientConfig:
117117
worker_gc_mode: Literal["disabled", "relaxed", "system"] = "relaxed"
118118

119119
# Request adapter for Query/Response <-> Payload/Response bytes
120-
adapter: type[HttpRequestAdapter] | None = None # None: use default
120+
# Default in __post_init__ if None
121+
adapter: type[HttpRequestAdapter] = None # type: ignore[assignment]
121122

122123
# SSE accumulator for streaming responses
123-
accumulator: type[SSEAccumulatorProtocol] | None = None # None: use default
124+
# Default in __post_init__ if None
125+
accumulator: type[SSEAccumulatorProtocol] = None # type: ignore[assignment]
124126

125127
# Worker pool transport class for worker IPC
126-
worker_pool_transport: type[WorkerPoolTransport] | None = None # None: use default
128+
# Default in __post_init__ if None
129+
worker_pool_transport: type[WorkerPoolTransport] = None # type: ignore[assignment]
127130

128131
def __post_init__(self):
129132
# set default adapter in __post_init__ to avoid circular dependency

src/inference_endpoint/endpoint_client/http.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -562,7 +562,10 @@ def protocol_factory() -> HttpResponseProtocol:
562562
self._creating -= 1
563563

564564
def release(self, conn: PooledConnection) -> None:
565-
"""Return connection to pool for reuse and notify waiters."""
565+
"""Return connection to pool for reuse and notify waiters (idempotent)."""
566+
if not conn.in_use:
567+
return
568+
566569
# Must close if: dead, server requested close, or error occurred
567570
if not conn.is_alive() or conn.protocol.should_close:
568571
self._close_connection(conn)
@@ -780,10 +783,10 @@ class InFlightRequest:
780783
query_id: Correlates response back to original Query.
781784
http_bytes: Serialized HTTP request for socket.write().
782785
is_streaming: Whether this is a streaming (SSE) request or not.
783-
connection: PooledConnection if any assigned to this request.
786+
connection: PooledConnection assigned to this request (set once request is fired).
784787
"""
785788

786789
query_id: str
787790
http_bytes: bytes
788791
is_streaming: bool
789-
connection: PooledConnection | None = field(default=None, repr=False)
792+
connection: PooledConnection = field(default=None, repr=False) # type: ignore[assignment]

src/inference_endpoint/endpoint_client/http_client.py

Lines changed: 20 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -30,9 +30,9 @@
3030
logger = logging.getLogger(__name__)
3131

3232

33-
class AsyncHttpEndpointClient:
33+
class HTTPEndpointClient:
3434
"""
35-
Async HTTP client for LLM inference.
35+
HTTP client for LLM inference.
3636
3737
Architecture:
3838
- Main process: Accepts requests, distributes to workers, handles responses
@@ -45,10 +45,12 @@ class AsyncHttpEndpointClient:
4545
4646
Usage:
4747
with ManagedZMQContext.scoped() as zmq_ctx:
48-
client = AsyncHttpEndpointClient(config, zmq_context=zmq_ctx)
48+
client = HTTPEndpointClient(config, zmq_context=zmq_ctx)
4949
client.issue(query)
50-
response = await client.recv()
51-
await client.shutdown()
50+
response = client.poll() # Non-blocking, returns None if nothing ready
51+
responses = client.drain() # Drain all available responses
52+
# response = await client.recv() # Blocking; only if caller provides its own loop
53+
client.shutdown() # Blocks until workers stop
5254
"""
5355

5456
def __init__(
@@ -60,6 +62,8 @@ def __init__(
6062
self.client_id = uuid.uuid4().hex[:8]
6163
self.config = config
6264
self._worker_cycle = cycle(range(self.config.num_workers))
65+
66+
# TODO(vir): make context setup/teardown part of transport protocol
6367
if config.worker_pool_transport is ZmqWorkerPoolTransport:
6468
if zmq_context is None:
6569
raise ValueError(
@@ -85,9 +89,6 @@ def __init__(
8589
# Initialize on event loop
8690
asyncio.run_coroutine_threadsafe(self._initialize(), self.loop).result()
8791

88-
assert self.config.adapter is not None
89-
assert self.config.accumulator is not None
90-
assert self.config.worker_pool_transport is not None
9192
logger.info(
9293
f"EndpointClient initialized with num_workers={self.config.num_workers}, "
9394
f"endpoints={self.config.endpoint_urls}, "
@@ -131,16 +132,20 @@ async def recv(self) -> QueryResult | StreamChunk | None:
131132

132133
def drain(self) -> list[QueryResult | StreamChunk]:
133134
"""Non-blocking. Returns all available responses."""
134-
results: list[QueryResult | StreamChunk] = []
135-
while (r := self.poll()) is not None:
136-
results.append(r)
137-
return results
135+
return list(iter(self.poll, None))
138136

139-
async def shutdown(self) -> None:
140-
"""Gracefully shutdown client."""
141-
logger.info(f"[{self.client_id}] Shutting down...")
137+
def shutdown(self) -> None:
138+
"""Gracefully shutdown client. Synchronous — blocks the caller until complete."""
139+
if self._shutdown: # Already shutdown, no-op
140+
return
141+
asyncio.run_coroutine_threadsafe(self._shutdown_async(), self.loop).result()
142+
143+
async def _shutdown_async(self) -> None:
144+
"""Async shutdown internals - must be called on the event loop."""
142145
self._shutdown = True
143146

147+
logger.info(f"[{self.client_id}] Shutting down...")
148+
144149
# Shutdown workers
145150
await self.worker_manager.shutdown()
146151

@@ -154,27 +159,3 @@ async def shutdown(self) -> None:
154159
f"[{self.client_id}] Dropped {self._dropped_requests} requests during shutdown"
155160
)
156161
logger.info(f"[{self.client_id}] Shutdown complete.")
157-
158-
159-
class HTTPEndpointClient(AsyncHttpEndpointClient):
160-
"""
161-
Sync HTTP client for LLM inference.
162-
Inherits from AsyncHttpEndpointClient and provides sync interface.
163-
164-
Usage:
165-
client = HTTPEndpointClient(config)
166-
client.issue(query)
167-
"""
168-
169-
def issue(self, query: Query) -> None: # type: ignore[override]
170-
"""Issue query."""
171-
# Schedule on event loop thread
172-
assert self.loop is not None
173-
self.loop.call_soon_threadsafe(
174-
lambda: super(HTTPEndpointClient, self).issue(query)
175-
)
176-
177-
def shutdown(self) -> None: # type: ignore[override]
178-
"""Sync shutdown wrapper - blocks until base class async shutdown completes."""
179-
assert self.loop is not None
180-
asyncio.run_coroutine_threadsafe(super().shutdown(), self.loop).result()

src/inference_endpoint/endpoint_client/http_sample_issuer.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,6 @@ def __init__(
5151
self.http_client = http_client
5252

5353
# Start response handler task to route completed responses back to SampleEventHandler
54-
assert self.http_client.loop is not None
5554
self._response_task = asyncio.run_coroutine_threadsafe(
5655
self._handle_responses(), self.http_client.loop
5756
)

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 24 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -159,20 +159,19 @@ def __init__(
159159
if self._scheme == "https":
160160
self._ssl_context = ssl.create_default_context()
161161

162-
# HTTP components
163-
self._pool: ConnectionPool | None = None
164-
self._http_template: HttpRequestTemplate | None = None
165-
self._loop: asyncio.AbstractEventLoop | None = None
162+
# HTTP components (initialized in run())
163+
self._pool: ConnectionPool = None # type: ignore[assignment]
164+
self._http_template: HttpRequestTemplate = None # type: ignore[assignment]
165+
self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]
166166

167-
# IPC transports
168-
self._requests: ReceiverTransport | None = None
169-
self._responses: SenderTransport | None = None
167+
# IPC transports (initialized in run())
168+
self._requests: ReceiverTransport = None # type: ignore[assignment]
169+
self._responses: SenderTransport = None # type: ignore[assignment]
170170

171171
# Track active request tasks
172172
self._active_tasks: set[asyncio.Task] = set()
173173

174174
# Use adapter type from config
175-
assert self.http_config.adapter is not None
176175
self._adapter: type[HttpRequestAdapter] = self.http_config.adapter
177176

178177
async def run(self) -> None:
@@ -184,7 +183,6 @@ async def run(self) -> None:
184183
# Use eager task factory for immediate coroutine execution
185184
# Tasks start executing synchronously until first await
186185
# NOTE(vir): CRITICAL for minimizing TFB/TTFT
187-
assert self._loop is not None
188186
self._loop.set_task_factory(asyncio.eager_task_factory) # type: ignore[arg-type]
189187

190188
# Initialize HTTP template from URL components
@@ -267,7 +265,9 @@ async def run(self) -> None:
267265
if self.http_config.record_worker_events:
268266
pid = os.getpid()
269267
worker_db_name = f"worker_report_{self.worker_id}_{pid}"
270-
assert self.http_config.event_logs_dir is not None
268+
assert (
269+
self.http_config.event_logs_dir is not None
270+
), "event_logs_dir must be set if record_worker_events is enabled"
271271
report_path = self.http_config.event_logs_dir / f"{worker_db_name}.csv"
272272

273273
with EventRecorder(session_id=worker_db_name) as event_recorder:
@@ -327,16 +327,13 @@ async def _run_main_loop(self) -> None:
327327
assert_active=True,
328328
)
329329

330-
# Prepare request
331-
prepared = self._prepare_request(query)
332-
333-
# Fire request
334-
if not await self._fire_request(prepared):
330+
# Prepare and fire request
331+
req = self._prepare_request(query)
332+
if not await self._fire_request(req):
335333
continue
336334

337335
# Process response asynchronously
338-
assert self._loop is not None
339-
task = self._loop.create_task(self._process_response(prepared))
336+
task = self._loop.create_task(self._process_response(req))
340337

341338
# Keep task alive to prevent GC
342339
# Cleaned up in _process_response finally block
@@ -359,7 +356,6 @@ def _prepare_request(self, query: Query) -> InFlightRequest:
359356
is_streaming = query.data.get("stream", False)
360357

361358
# Build complete HTTP request bytes
362-
assert self._http_template is not None
363359
http_bytes = self._http_template.build_request(
364360
body_bytes,
365361
is_streaming,
@@ -381,23 +377,21 @@ async def _fire_request(self, req: InFlightRequest) -> bool:
381377
Fire HTTP POST request:
382378
1. Acquire TCP connection from pool
383379
2. Send POST request bytes
384-
3. Store connection for process_response task
385380
386-
Returns True on success.
381+
Returns True on success, False on failure (error response sent).
387382
"""
388383
if self._shutdown:
389384
await self._handle_error(req.query_id, "Worker is shutting down")
390385
return False
391386

392387
try:
393388
# Acquire connection from pool
394-
assert self._pool is not None
395389
conn = await self._pool.acquire()
396390

397391
# Write request bytes directly to transport
398392
conn.protocol.write(req.http_bytes)
399393

400-
# Store connection for _process_response to use
394+
# Store connection on req for response processing
401395
req.connection = conn
402396

403397
return True
@@ -410,18 +404,14 @@ async def _fire_request(self, req: InFlightRequest) -> bool:
410404
@profile
411405
async def _process_response(self, req: InFlightRequest) -> None:
412406
"""Process response for a fired request."""
413-
try:
414-
conn = req.connection
415-
assert conn is not None, "Connection should be set by _fire_request"
407+
conn = req.connection
416408

409+
try:
417410
# Await headers and handle error status
418411
status_code, _ = await conn.protocol.read_headers()
419412
if status_code != 200:
420413
error_body = await conn.protocol.read_body()
421-
# Release connection early - done with socket I/O
422-
assert self._pool is not None
423414
self._pool.release(conn)
424-
req.connection = None
425415
await self._handle_error(
426416
req.query_id,
427417
f"HTTP {status_code}: {error_body.decode('utf-8', errors='replace')}",
@@ -439,11 +429,8 @@ async def _process_response(self, req: InFlightRequest) -> None:
439429
logger.warning(f"Request {req.query_id} failed: {type(e).__name__}: {e}")
440430

441431
finally:
442-
# Release connection back to pool if not already released
443-
if req.connection:
444-
assert self._pool is not None
445-
self._pool.release(req.connection)
446-
req.connection = None
432+
# Release connection back to pool if not already
433+
self._pool.release(conn)
447434

448435
# Record completion event
449436
if self.http_config.record_worker_events:
@@ -462,18 +449,15 @@ async def _process_response(self, req: InFlightRequest) -> None:
462449
@profile
463450
async def _handle_streaming_body(self, req: InFlightRequest) -> None:
464451
"""Handle streaming (SSE) response body."""
465-
conn = req.connection
466-
assert conn is not None
467452
query_id = req.query_id
453+
conn = req.connection
468454

469455
# Create accumulator for streaming response
470-
assert self.http_config.accumulator is not None
471456
accumulator = self.http_config.accumulator(
472457
query_id, self.http_config.stream_all_chunks
473458
)
474459

475460
# Process SSE stream - yields batches of chunks
476-
assert self._responses is not None
477461
async for chunk_batch in self._iter_sse_lines(conn):
478462
for delta in chunk_batch:
479463
if stream_chunk := accumulator.add_chunk(delta):
@@ -487,10 +471,8 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None:
487471
assert_active=True,
488472
)
489473

490-
# Release connection early - done with socket I/O
491-
assert self._pool is not None
474+
# Release connection early - done with socket I/O (idempotent)
492475
self._pool.release(conn)
493-
req.connection = None
494476

495477
# Send final complete back to main rank
496478
self._responses.send(accumulator.get_final_output())
@@ -505,23 +487,19 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None:
505487
@profile
506488
async def _handle_non_streaming_body(self, req: InFlightRequest) -> None:
507489
"""Handle non-streaming response body."""
508-
conn = req.connection
509-
assert conn is not None
510490
query_id = req.query_id
491+
conn = req.connection
511492

512493
# Read entire response body
513494
response_bytes = await conn.protocol.read_body()
514495

515-
# Release connection early - done with socket I/O
516-
assert self._pool is not None
496+
# Release connection early - done with socket I/O (idempotent)
517497
self._pool.release(conn)
518-
req.connection = None
519498

520499
# Decode using adapter
521500
result = self._adapter.decode_response(response_bytes, query_id)
522501

523502
# Send result back to main rank
524-
assert self._responses is not None
525503
self._responses.send(result)
526504
if self.http_config.record_worker_events:
527505
EventRecorder.record_event(

0 commit comments

Comments (0)